In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to the local helper
# modules imported below (callbacks, CNN_Resnet, CNN_Efficient) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch.utils.data import random_split, DataLoader, TensorDataset
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt 
from copy import copy

In [None]:
from callbacks import EarlyStopping
from CNN_Resnet import CNNModel
import os

In [None]:
# Random augmentations for training images, applied in the order listed.
train_augmentations = [
    tt.RandomCrop(200, padding=25, padding_mode='reflect'),
    tt.RandomHorizontalFlip(),
    tt.RandomRotation(10),
    tt.RandomPerspective(distortion_scale=0.2),
]
# Tensor conversion always comes last, after every augmentation.
train_tfms = tt.Compose(train_augmentations + [tt.ToTensor()])

# Evaluation images are not augmented: tensor conversion only.
test_tfms = tt.Compose([tt.ToTensor()])

In [None]:
# Base directory for the data; the training images sit in a doubly nested
# "asl_alphabet_train" folder underneath it.
data_dir = "."
train_root = data_dir + '/asl_alphabet_train/asl_alphabet_train'

# No transform is passed here; transforms are attached per subset after the
# train/validation split.
dataset = ImageFolder(train_root)

In [None]:
# Fraction of the samples held out for validation; the rest is used for training.
VAL_FRACTION = 0.15
# Fixed seed so the train/validation partition is identical on every run.
SPLIT_SEED = 42

val_size = int(VAL_FRACTION * len(dataset))
train_size = len(dataset) - val_size

# Without a seeded generator the split changes on every kernel restart, so a
# sample that was in validation in one run can end up in training in the next,
# which makes validation metrics and saved checkpoints incomparable across runs.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_ds, val_ds = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Display the resulting subset sizes as the cell output.
len(train_ds), len(val_ds)

In [None]:
# Both subsets returned by random_split refer to the same underlying `dataset`
# object, so a transform set on it directly would apply to both. Give the
# training subset its own shallow copy and attach the augmenting transform to
# that copy only.
train_view = copy(dataset)
train_view.transform = train_tfms
train_ds.dataset = train_view

# The validation subset keeps the original dataset object, which receives the
# plain (non-augmenting) transform.
val_ds.dataset.transform = test_tfms

In [None]:
# Mini-batch size used for training (the validation loader uses a multiple of it).
batch_size = 64

# Training batches are drawn in shuffled order.
train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)

In [None]:
# Validation batches are twice the training batch size and keep the loader's
# default (unshuffled) order.
val_dl = DataLoader(dataset=val_ds, batch_size=2 * batch_size)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Keyword arguments passed unchanged to both model constructors (CNNModel and
# CNNModel_Eff).
# NOTE(review): out_features is presumably the number of class folders in the
# dataset — confirm against len(dataset.classes).
hyperpara = dict(out_features=29, lr=1e-3)

In [None]:
# Instantiate the model from the local CNN_Resnet module, expanding the shared
# hyperparameter dict into keyword arguments (out_features, lr).
tmodel = CNNModel(**hyperpara)

In [None]:
# The loaders are attached as attributes on the model; fit() is called later
# without loader arguments, so it presumably reads them from here.
tmodel.train_loader = train_dl
tmodel.valid_loader = val_dl

MODEL_PATH = "models/"
MODEL_NAME = "ResNet"
EPOCHS = 6

# Create the checkpoint directory up front so that neither the early-stopping
# checkpoint nor the final save fails on a fresh checkout where "models/" does
# not exist yet. exist_ok=True makes this safe to re-run.
os.makedirs(MODEL_PATH, exist_ok=True)

# Early-stopping callback configured to monitor "valid_loss" in "min" mode with
# a patience of 3; its checkpoint path is models/ResNet_early.bin.
# NOTE(review): exact stopping/checkpoint semantics live in callbacks.EarlyStopping.
es = EarlyStopping(
    monitor="valid_loss",
    model_path=os.path.join(MODEL_PATH, MODEL_NAME + "_early.bin"),
    patience=3,
    mode="min",
)

In [None]:
# Train the ResNet model on the selected device for up to EPOCHS epochs, with
# the early-stopping callback attached.
tmodel.fit(device=device, epochs=EPOCHS, callbacks=[es])

In [None]:
# Save the ResNet model as it stands after fit() returns; this file is separate
# from the early-stopping checkpoint (models/ResNet_early.bin).
tmodel.save_model("models/final_resnet.pth")

In [1]:
from CNN_Efficient import CNNModel_Eff

In [2]:
# Instantiate the second model (from the local CNN_Efficient module) with the
# same hyperparameters as the ResNet run. The recorded output shows that the
# first construction downloads efficientnet_b0 weights into the torch hub cache.
efmodel = CNNModel_Eff(**hyperpara)

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth" to /Users/gauthamsreekumar/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-3dd342df.pth
100%|██████████| 20.5M/20.5M [02:51<00:00, 125kB/s] 


In [None]:
# Reuse the same train/validation loaders as the ResNet run. They are attached
# as attributes; fit() is called later without loader arguments, so it
# presumably reads them from here.
efmodel.train_loader = train_dl
efmodel.valid_loader = val_dl

MODEL_PATH = "models/"
MODEL_NAME = "Efficient"
EPOCHS = 6

# Create the checkpoint directory up front so that neither the early-stopping
# checkpoint nor the final save fails when "models/" does not exist yet.
# exist_ok=True makes this safe to re-run.
os.makedirs(MODEL_PATH, exist_ok=True)

# A fresh early-stopping callback for this model (monitor "valid_loss", "min"
# mode, patience 3); its checkpoint path is models/Efficient_early.bin.
# NOTE(review): exact stopping/checkpoint semantics live in callbacks.EarlyStopping.
es = EarlyStopping(
    monitor="valid_loss",
    model_path=os.path.join(MODEL_PATH, MODEL_NAME + "_early.bin"),
    patience=3,
    mode="min",
)

In [None]:
# Train the EfficientNet model on the selected device for up to EPOCHS epochs,
# with its own early-stopping callback attached.
efmodel.fit(device=device, epochs=EPOCHS, callbacks=[es])

In [None]:
# Save the EfficientNet model after training. The original line called
# save_model on `tmodel` (the ResNet model), which wrote ResNet weights under
# the EfficientNet filename and never saved `efmodel` at all.
efmodel.save_model("models/final_eff.pth")