I was trying to implement CapsuleNet for classifying digits. All the images are RGB, resized to 32 × 32, and the dataset has 10 output classes.
```
X_train_all.shape: (745, 32, 32, 3)
y_train_all.shape: (745, 10)
```
First, I define the CapsNet model, which takes the following parameters:
```python
# define model
model = CapsNet()
```
OK, I think this much information is enough here. Next, I define a function that handles both data augmentation and training the model. It is called on each shuffled fold of the dataset, as defined below.
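```python
from sklearn.model_selection import KFold

Fold = 1  # fold counter (starting value not shown in the original post)
kfold = KFold(n_splits=10, shuffle=True, random_state=42)
for train, val in kfold.split(X_train_all, y_train_all):
    print('Fold:', Fold)
    X_train = X_train_all[train]
    X_val = X_train_all[val]
    y_train = y_train_all[train]
    y_val = y_train_all[val]
    # train the model with data augmentation
    train(model=model, data=((X_train, y_train), (X_val, y_val)), epoch_size_frac=0.5)
    Fold = Fold + 1
```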
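`KFold.split` yields one pair of integer index arrays per fold, and the loop above binds them to the names `train` and `val`. As a hedged illustration of what those names hold, here is a numpy-only sketch of a shuffled 10-fold split. `kfold_indices` is a hypothetical helper: it reproduces KFold's fold sizes (the first `n_samples % n_splits` folds get one extra sample) but not its exact shuffle order, and it deliberately uses `train_idx`/`val_idx` as loop names.

```python
import numpy as np


def kfold_indices(n_samples, n_splits=10, seed=42):
    """Hypothetical helper: yield (train_idx, val_idx) integer arrays for a
    shuffled K-fold split. Fold sizes match sklearn's KFold rule, but the
    shuffled order differs from KFold(shuffle=True, random_state=42)."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)        # shuffled sample indices
    folds = np.array_split(order, n_splits)   # n_splits nearly equal chunks
    for k in range(n_splits):
        val_idx = folds[k]
        train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != k])
        yield train_idx, val_idx


# Each iteration yields two NumPy arrays of integer indices, not data.
for fold, (train_idx, val_idx) in enumerate(kfold_indices(745), start=1):
    print(fold, type(train_idx).__name__, len(train_idx), len(val_idx))
```

For the 745 samples in this issue, the first five folds hold 75 validation samples and the remaining five hold 74.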
The `train` function called above is the one you defined, as follows:
```python
def train(model, data, epoch_size_frac=1.0):
    # unpacking the data
    (x_train, y_train), (x_val, y_val) = data

    # compile the model
    model.compile()

    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator

    # Training with data augmentation.
    model.fit_generator

    return model
```
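The body of `train_generator` is omitted above. For readers unfamiliar with shift augmentation, here is a numpy-only sketch of what a `shift_fraction`-style generator does: it samples a batch, translates each image by a random integer offset of up to `shift_fraction` of its height/width, zero-fills the uncovered border, and passes the labels through unchanged. `shift_image` and `shift_batch_generator` are hypothetical names, and this is not a reimplementation of Keras's `ImageDataGenerator`; it assumes `0 <= shift_fraction < 1` and `batch_size <= len(x)`.

```python
import numpy as np


def shift_image(img, dy, dx):
    """Translate one (H, W, C) image by (dy, dx) pixels, zero-filling the
    uncovered border. Assumes |dy| <= H and |dx| <= W."""
    h, w = img.shape[:2]
    out = np.zeros_like(img)
    src_y = slice(max(0, -dy), min(h, h - dy))
    dst_y = slice(max(0, dy), min(h, h + dy))
    src_x = slice(max(0, -dx), min(w, w - dx))
    dst_x = slice(max(0, dx), min(w, w + dx))
    out[dst_y, dst_x] = img[src_y, src_x]
    return out


def shift_batch_generator(x, y, batch_size, shift_fraction=0.0, seed=0):
    """Hypothetical generator: yield (x_batch, y_batch) forever, with each
    image randomly shifted; labels are returned unchanged."""
    rng = np.random.default_rng(seed)
    n, h, w = x.shape[:3]
    max_dy, max_dx = int(shift_fraction * h), int(shift_fraction * w)
    while True:
        idx = rng.choice(n, size=batch_size, replace=False)  # batch without repeats
        x_batch = np.stack([
            shift_image(x[i],
                        int(rng.integers(-max_dy, max_dy + 1)),
                        int(rng.integers(-max_dx, max_dx + 1)))
            for i in idx
        ])
        yield x_batch, y[idx]
```

For the 32 × 32 images in this issue, `shift_fraction=0.1` would allow shifts of up to 3 pixels in each direction, since `int(0.1 * 32) == 3`.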
Now, when I start training, I get the following error:
```
TypeError                                 Traceback (most recent call last)
<ipython-input-67-a8ca211d12a2> in <module>()
     19
     20
---> 21 train(model = model, data = ((X_train, y_train), (X_val, y_val)), epoch_size_frac = 0.5)
     22
     23

TypeError: 'numpy.ndarray' object is not callable
```
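The traceback looks like a name collision rather than a failure inside `train()` itself: it stops at the call on line 21 with no deeper frame, and in the fold loop `for train, val in kfold.split(...)` binds the name `train` to the array of training indices, so by the time of the call `train` is a NumPy array, not the function. A minimal reproduction, using a hypothetical stand-in `train()` that only returns fold sizes, together with the rename fix:

```python
import numpy as np


def train(model, data, epoch_size_frac=1.0):
    """Hypothetical stand-in for the real train(): returns the fold sizes."""
    (x_train, y_train), (x_val, y_val) = data
    return len(x_train), len(x_val)


X = np.zeros((10, 2))
y = np.zeros((10, 10))
splits = [(np.arange(8), np.arange(8, 10))]  # one (train_idx, val_idx) pair


def buggy_loop():
    # The loop target `train` shadows the function with an index array,
    # so the call below raises TypeError: 'numpy.ndarray' object is not callable
    for train, val in splits:
        return train(model=None, data=((X[train], y[train]), (X[val], y[val])))


def fixed_loop():
    # Distinct loop names leave the train() function reachable.
    sizes = []
    for train_idx, val_idx in splits:
        sizes.append(train(model=None,
                           data=((X[train_idx], y[train_idx]), (X[val_idx], y[val_idx]))))
    return sizes


try:
    buggy_loop()
except TypeError as err:
    print("buggy_loop:", err)
print("fixed_loop:", fixed_loop())
```

In a notebook the fold loop runs at the top level, so it overwrites the global name `train`; after renaming the loop variables, re-run the cell that defines `train()` (or restart the kernel) before running the loop again.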