In [15]:
import tensorflow as tf
from tensorflow.keras.models import Sequential , load_model
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D
from utils import get_data_loaders
import torch
import numpy as np
import matplotlib.pyplot as plt
from keras.callbacks import ModelCheckpoint


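## Create the CNN model

A small two-block convolutional network that maps each 32×32 RGB image to a 256-dimensional feature vector.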

In [8]:
# Regression CNN: maps a 32x32 RGB image (channels-last) to a 256-dim feature vector.
# NOTE(review): the targets are continuous `features` fitted with an MSE loss, so the
# output layer is linear. A softmax output would force the 256 predictions to be
# positive and sum to 1, which only makes sense if the targets are probability
# distributions — if they are, set OUTPUT_ACTIVATION = "softmax".
SEED = 42
INPUT_SHAPE = (32, 32, 3)
N_FEATURES = 256
OUTPUT_ACTIVATION = "linear"

# Seeds Python `random`, NumPy and TensorFlow so weight initialisation and Keras'
# shuffling are reproducible (does not seed torch).
tf.keras.utils.set_random_seed(SEED)

model = Sequential()
model.add(Conv2D(64, (3, 3), input_shape=INPUT_SHAPE))
model.add(Activation("relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3)))
model.add(Activation("relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(N_FEATURES))
model.add(Activation(OUTPUT_ACTIVATION))

# 'accuracy' is a classification metric (Keras resolves it to an argmax-matching
# accuracy for a 256-wide output), which says nothing about a regression fit;
# report mean absolute error alongside the MSE loss instead.
model.compile(loss="mean_squared_error", optimizer="adam", metrics=["mae"])

## Create the data loaders

In [9]:

# Pre-built .npz splits consumed by `get_data_loaders`, keyed by split name.
data_files = {
    'train': './data/12000_train_mnistmnistmsvhnsynusps.npz',
    'test': './data/12000_test_mnistmnistmsvhnsynusps.npz',
}

# Only the first return value is kept; it is indexed by split name (e.g. 'train')
# in the training cell. The second return value is discarded.
full_dataloaders, _ = get_data_loaders(filenames=data_files, batch_size=6000)

datafiles to read:  {'train': './data/12000_train_mnistmnistmsvhnsynusps.npz', 'test': './data/12000_test_mnistmnistmsvhnsynusps.npz'}
reading ./data/12000_train_mnistmnistmsvhnsynusps.npz, number of samples: 60000
reading ./data/12000_test_mnistmnistmsvhnsynusps.npz, number of samples: 21600
reading ./data/12000_test_mnistmnistmsvhnsynusps.npz, number of samples: 21600


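## Load data and train the CNN model

Only the first batch drawn from the training loader is used for fitting (6,000 of the 60,000 training samples, since `batch_size=6000`), and 10% of that batch is held out for validation.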

In [10]:

# Draw a single batch from the training loader; the remaining training samples are
# not used in this run.
train_batch = next(iter(full_dataloaders['train']))
images, features, domain_labels, digit_labels = train_batch

print('images shape: ', images.shape)
print('features shape: ', features.shape)
print('domain labels freq: ', torch.unique(domain_labels, return_counts=True))
print('digit labels freq: ', torch.unique(digit_labels, return_counts=True))

# The loader yields channels-first images (N, 3, 32, 32); Keras Conv2D expects
# channels-last (N, 32, 32, 3), so move the channel axis to the end.
img = images.permute(0, 2, 3, 1)
model.fit(
    img.numpy(),
    features.numpy(),
    epochs=10,
    batch_size=32,
    validation_split=0.1,
)
       

images shape:  torch.Size([6000, 3, 32, 32])
features shape:  torch.Size([6000, 256])
domain labels freq:  (tensor([0, 1, 2, 3, 4]), tensor([1224, 1217, 1180, 1204, 1175]))
digit labels freq:  (tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([666, 783, 658, 603, 573, 496, 580, 560, 555, 526]))
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x16dc8a4a0>

In [16]:
# Persist the trained model, reload it, and confirm the round trip preserved the weights.
# NOTE: Keras 3 treats HDF5 (.h5) as a legacy format and recommends the native
# `.keras` format; the original .h5 filename is kept here for compatibility.
MODEL_PATH = 'my_model.h5'

# Save the trained model
model.save(MODEL_PATH)

# Load the saved model
loaded_model = load_model(MODEL_PATH)

# Sanity check: every weight tensor must survive the save/load cycle unchanged.
saved_weights = model.get_weights()
restored_weights = loaded_model.get_weights()
assert len(saved_weights) == len(restored_weights), "weight count changed after reload"
for saved, restored in zip(saved_weights, restored_weights):
    np.testing.assert_array_equal(saved, restored)
