-
Notifications
You must be signed in to change notification settings - Fork 0
/
noncontinual_classifier.py
27 lines (20 loc) · 1.07 KB
/
noncontinual_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from continual_classifier import ContinualClassifier
from utils import estimate_fisher_diagonal
from keras.losses import categorical_crossentropy
import numpy as np
from keras.models import Model
from keras.layers import Lambda
import tensorflow as tf
import keras.backend as K
from sklearn.utils import shuffle
from tqdm import tqdm
class DeepClassifier(ContinualClassifier):
    """Plain deep classifier built on ``ContinualClassifier``.

    Each task is trained by compiling the given Keras model with the configured
    loss/optimizer and calling ``fit``; no additional loss terms are added here.
    """

    def __init__(self, optimizer='sgd', loss='categorical_crossentropy',
                 metrics=None, singleheaded_classes=None, model=None):
        """Configure the classifier and forward everything to the base class.

        Args:
            optimizer: optimizer passed (via the base class) to ``model.compile``.
            loss: loss passed (via the base class) to ``model.compile``.
            metrics: metric list; ``None`` (the default) means ``['accuracy']``.
            singleheaded_classes: forwarded unchanged to the base class.
            model: architecture description dict; ``None`` (the default) means
                ``{'layers': 3, 'units': 200, 'dropout': 0, 'activation': 'relu'}``.
        """
        # Build the defaults per call: list/dict default arguments would be a
        # single object shared by every instance (and by any in-place change
        # the base class makes to them).
        if metrics is None:
            metrics = ['accuracy']
        if model is None:
            model = {'layers': 3, 'units': 200, 'dropout': 0, 'activation': 'relu'}
        super().__init__(optimizer, loss, metrics, singleheaded_classes, model)

    def save_model(self, filename):
        """No-op: persistence is not implemented for this classifier."""
        # NOTE(review): silently does nothing; callers expecting a saved file
        # will not get one. Kept as a no-op to preserve existing behavior.
        pass

    def load_model(self, filename):
        """No-op: persistence is not implemented for this classifier."""
        # NOTE(review): silently does nothing; see save_model.
        pass

    def task_fit_method(self, X, Y, model, new_task, batch_size, epochs,
                        validation_data=None, verbose=2):
        """Compile ``model`` and train it on one task's data.

        Args:
            X: training inputs, passed straight to ``model.fit``.
            Y: training targets, passed straight to ``model.fit``.
            model: Keras model; it is (re)compiled and trained in place.
            new_task: not used by this implementation (presumably part of the
                base-class hook signature -- confirm).
            batch_size: batch size for ``model.fit``.
            epochs: number of epochs for ``model.fit``.
            validation_data: optional validation data for ``model.fit``.
            verbose: Keras verbosity level for ``model.fit``.

        Returns:
            The ``History`` object returned by ``model.fit``.
        """
        # Honour the metrics given to the constructor. The base class presumably
        # stores them as ``self.metrics`` (as it does self.loss/self.optimizer)
        # -- confirm; otherwise fall back to the previous hard-coded value.
        metrics = getattr(self, 'metrics', None)
        if metrics is None:
            metrics = ['accuracy']
        model.compile(loss=self.loss, optimizer=self.optimizer, metrics=metrics)
        return model.fit(X, Y,
                         batch_size=batch_size,
                         epochs=epochs,
                         validation_data=validation_data,
                         verbose=verbose,
                         shuffle=True)