Skip to content
Permalink
Browse files

Initial commit

  • Loading branch information
Satya Kesav
Satya Kesav committed Mar 21, 2019
1 parent e9dbf66 commit ea38c654e50b39b13bf327c9705b4ec07f98c23c
Showing with 50 additions and 6 deletions.
  1. +1 −1 autokeras/supervised.py
  2. +0 −1 autokeras/text/pretrained_bert/modeling.py
  3. +49 −4 autokeras/text/text_supervised.py
@@ -286,7 +286,7 @@ def evaluate(self, x_test, y_test):
class SingleModelSupervised(Supervised):
"""The base class for all supervised tasks that do not use neural architecture search.
Inheirits from Supervised class.
Inherits from Supervised class.
Attributes:
verbose: A boolean value indicating the verbosity mode.
@@ -637,4 +637,3 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=No
return loss
else:
return logits

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import ABC, abstractmethod

import numpy as np
import os
@@ -33,7 +33,53 @@
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler


class TextClassifier(SingleModelSupervised, ABC):
class TextSupervised(SingleModelSupervised, ABC):

def __init__(self, verbose, **kwargs):
super().__init__(verbose=verbose, **kwargs)
self.device = get_device()

# BERT specific
self.bert_model = 'bert-base-uncased'
self.tokenizer = BertTokenizer.from_pretrained(self.bert_model, do_lower_case=True)

# Labels/classes
self.num_labels = None

@abstractmethod
def fit(self, x, y, time_limit=None):
pass

@abstractmethod
def predict(self, x_test):
pass

@property
@abstractmethod
def metric(self):
pass

@property
@abstractmethod
def loss(self):
pass

@abstractmethod
def preprocess(self, x):
pass

def transform_y(self, y):
pass

def inverse_transform_y(self, output):
return np.argmax(output, axis=1)


class TextRegressor(TextSupervised):
pass


class TextClassifier(TextSupervised):
"""A TextClassifier class based on Google AI's BERT model.
Attributes:
@@ -51,9 +97,8 @@ def __init__(self, verbose, **kwargs):
Args:
verbose: Mode of verbosity.
"""
super().__init__(**kwargs)
super().__init__(verbose=verbose, **kwargs)
self.device = get_device()
self.verbose = verbose

# BERT specific
self.bert_model = 'bert-base-uncased'

0 comments on commit ea38c65

Please sign in to comment.
You can’t perform that action at this time.