In [18]:
import torch
import torchvision
import os
from tqdm.auto import tqdm
from NNlib import create_dataloaders, get_initial_model, train

In [2]:
# define constants
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 0.001
batch_size = 32
num_workers = os.cpu_count()
num_classes = 100
num_test_img = 6667

NN_train_dir = 'data/NN_train_all/'
NN_validate_dir = 'data/NN_validate_all/'
NN_test_dir = 'data/NN_test_all/'
test_data_dir = 'test_filt.csv'
model_temp_file = 'model_temp'
model_file = 'model_4_3_V6'

In [3]:
# define data transformation
data_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# create dataset from pre-processed images
train_dataloader, test_dataloader, class_names = create_dataloaders(NN_train_dir,NN_validate_dir,data_transform,batch_size,num_workers)

In [4]:
# initial training model
model_0 = get_initial_model(num_classes, device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_0.parameters(), lr)



In [None]:
# lightly train the classifier while feature layers are frozen
model_0_results = train(model=model_0, train_dataloader=train_dataloader, test_dataloader=test_dataloader,
                       optimizer=optimizer, loss_fn=loss_fn, epochs=7, device=device)

In [None]:
torch.save(model_0.state_dict(), model_temp_file)

In [6]:
model_0.load_state_dict(torch.load(model_temp_file))

<All keys matched successfully>

In [6]:
# unfreeze feature layers
for param in model_0.features.parameters():
    param.requires_grad = True
    
lr = 0.00001

In [7]:
# train the entire model with small learning rate
model_0_results = train(model=model_0, train_dataloader=train_dataloader, test_dataloader=test_dataloader,
                       optimizer=optimizer, loss_fn=loss_fn, epochs=5, device=device)

 33%|███▎      | 1/3 [26:40<53:21, 1600.74s/it]

Epoch: 1 | train_loss: 0.7657 | train_acc: 0.8328 | test_loss: 0.8888 | test_acc: 0.8297


 67%|██████▋   | 2/3 [53:06<26:31, 1591.91s/it]

Epoch: 2 | train_loss: 0.5783 | train_acc: 0.8656 | test_loss: 0.9324 | test_acc: 0.8246


100%|██████████| 3/3 [1:19:27<00:00, 1589.14s/it]

Epoch: 3 | train_loss: 0.4550 | train_acc: 0.8883 | test_loss: 0.9037 | test_acc: 0.8315





In [8]:
torch.save(model_0.state_dict(), model_file)