Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Satya Kesav committed Mar 21, 2019
1 parent e9dbf66 commit ea38c65
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 6 deletions.
2 changes: 1 addition & 1 deletion autokeras/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def evaluate(self, x_test, y_test):
class SingleModelSupervised(Supervised):
"""The base class for all supervised tasks that do not use neural architecture search.
Inheirits from Supervised class.
Inherits from Supervised class.
Attributes:
verbose: A boolean value indicating the verbosity mode.
Expand Down
1 change: 0 additions & 1 deletion autokeras/text/pretrained_bert/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,4 +637,3 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=No
return loss
else:
return logits

53 changes: 49 additions & 4 deletions autokeras/text/text_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import ABC, abstractmethod

import numpy as np
import os
Expand All @@ -33,7 +33,53 @@
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler


class TextClassifier(SingleModelSupervised, ABC):
class TextSupervised(SingleModelSupervised, ABC):

def __init__(self, verbose, **kwargs):
super().__init__(verbose=verbose, **kwargs)
self.device = get_device()

# BERT specific
self.bert_model = 'bert-base-uncased'
self.tokenizer = BertTokenizer.from_pretrained(self.bert_model, do_lower_case=True)

# Labels/classes
self.num_labels = None

@abstractmethod
def fit(self, x, y, time_limit=None):
pass

@abstractmethod
def predict(self, x_test):
pass

@property
@abstractmethod
def metric(self):
pass

@property
@abstractmethod
def loss(self):
pass

@abstractmethod
def preprocess(self, x):
pass

def transform_y(self, y):
pass

def inverse_transform_y(self, output):
return np.argmax(output, axis=1)


class TextRegressor(TextSupervised):
pass


class TextClassifier(TextSupervised):
"""A TextClassifier class based on Google AI's BERT model.
Attributes:
Expand All @@ -51,9 +97,8 @@ def __init__(self, verbose, **kwargs):
Args:
verbose: Mode of verbosity.
"""
super().__init__(**kwargs)
super().__init__(verbose=verbose, **kwargs)
self.device = get_device()
self.verbose = verbose

# BERT specific
self.bert_model = 'bert-base-uncased'
Expand Down

0 comments on commit ea38c65

Please sign in to comment.