In [None]:
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet34
# Only the progress-bar class comes from tqdm; `time` and `copy` must remain the
# stdlib modules imported above (train_model calls time.time and copy.deepcopy).
from tqdm import tqdm

In [None]:
# Dataset root directory; the 'train' and 'valid' split folders live under it.
root = './sample_data/'
batch_size = 64
lr = 1e-2
# Seed the RNGs so data shuffling, random flips and the new head's init are reproducible.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Per-channel normalisation statistics; presumably computed on this dataset's
# training images (RGB order) — confirm.
mean = [0.469, 0.526, 0.575]
std = [0.155, 0.147, 0.128]
# Resize with an int only matches the *shorter* edge to 224, so non-square images
# would keep different shapes and the DataLoader could not stack them into a batch.
# CenterCrop(224) guarantees every sample is 224x224.
data_transformation = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

In [None]:
# Validation uses the training pipeline minus the random flips, so evaluation is
# deterministic and not distorted by augmentation.
eval_transformation = transforms.Compose(
    [t for t in data_transformation.transforms
     if not isinstance(t, (transforms.RandomHorizontalFlip, transforms.RandomVerticalFlip))])

# One ImageFolder per split, read from <root>/train and <root>/valid; ImageFolder
# derives the class labels from the sub-folder names of each split directory.
cloudClfDataset = {mode: datasets.ImageFolder(root=os.path.join(root, mode),
                                              transform=(data_transformation if mode == 'train'
                                                         else eval_transformation))
                   for mode in ['train', 'valid']}

In [None]:
# Number of images in each split; train_model averages loss/accuracy over these counts.
dataset_sizes = dict((split, len(cloudClfDataset[split])) for split in ('train', 'valid'))

In [None]:
# Batched loaders per split. Only the training split is shuffled; a fixed validation
# order keeps evaluation runs comparable and costs nothing in accuracy.
cloudClfDataset_loader = {mode: DataLoader(dataset=cloudClfDataset[mode],
                                           batch_size=batch_size,
                                           shuffle=(mode == 'train'),
                                           num_workers=0)
                          for mode in ['train', 'valid']}

In [None]:
# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# ImageNet-pretrained ResNet-34 used as a frozen feature extractor.
cloudModel = resnet34(pretrained=True)
# Freeze the backbone. The attribute is `requires_grad`; a misspelt name just
# creates an unused attribute and leaves every parameter trainable.
for param in cloudModel.parameters():
    param.requires_grad = False
# Replace the classifier head with a fresh (trainable) layer for 3 classes;
# presumably 3 == number of class folders under root/train — confirm.
cloudModel.fc = nn.Linear(cloudModel.fc.in_features, 3)

In [None]:
# Move the model's parameters and buffers onto the selected device.
cloudModel = cloudModel.to(device=device)

In [None]:
# Multi-class classification loss (mean over the batch), placed on the same device as the model.
criterion = nn.CrossEntropyLoss(reduction='mean')
criterion = criterion.to(device)

In [None]:
# Plain SGD over the new classification head only, so the pretrained backbone is not
# updated. The model object defined above is `cloudModel` (`model_fit` does not exist
# until training has run).
optimiser = torch.optim.SGD(params=cloudModel.fc.parameters(), lr=lr, momentum=0)

In [None]:
# Multiply the learning rate by `gamma` every `step_size` scheduler steps
# (train_model steps it once per epoch).
# NOTE(review): gamma=1e-3 shrinks the lr 1000x per decay (1e-2 -> 1e-5 -> 1e-8),
# which practically halts learning after the first decay; 0.1 is the usual
# choice — confirm this is intended.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimiser, step_size=5, gamma=1e-3)

In [None]:
def train_model(model, dataloader, criterion, optimiser, scheduler, device, dataset_size, num_epochs=10):
    """Train `model`, evaluating on the validation split after every epoch.

    Each epoch runs a 'train' phase (gradients enabled, optimiser stepped per batch,
    scheduler stepped once afterwards) followed by a 'valid' phase (gradients off).
    The weights with the highest validation accuracy are restored into `model`
    before returning. A KeyboardInterrupt (Ctrl-C) stops training early and still
    restores the best weights seen so far.

    Args:
        model: nn.Module to train; it is modified in place.
        dataloader: dict with 'train' and 'valid' iterables of (images, labels) batches.
        criterion: loss callable invoked as criterion(outputs, labels).
        optimiser: optimiser over the parameters that should be trained.
        scheduler: lr scheduler, stepped once per epoch after the training phase.
        device: device each batch is moved to.
        dataset_size: dict mapping 'train'/'valid' to the number of samples, used to
            average the summed loss and correct-prediction count.
        num_epochs: number of epochs to run.

    Returns:
        The same `model` object, loaded with the best-validation weights (its initial
        weights if no epoch reached a validation accuracy above 0).
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    try:
        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')

            for phase in ['train', 'valid']:
                is_train = phase == 'train'
                model.train(is_train)  # train(False) is equivalent to eval()

                running_loss = 0.0
                running_corrects = 0

                for images, labels in dataloader[phase]:
                    images = images.to(device)
                    labels = labels.to(device)

                    if is_train:
                        optimiser.zero_grad()

                    with torch.set_grad_enabled(is_train):
                        outputs = model(images)
                        preds = outputs.argmax(dim=1)
                        loss = criterion(outputs, labels)

                        if is_train:
                            loss.backward()
                            optimiser.step()

                    batch_n = images.size(0)
                    batch_loss = loss.item()
                    # Accumulate plain Python numbers so an empty loader cannot break
                    # the epoch statistics below.
                    batch_corrects = (preds == labels).sum().item()
                    running_loss += batch_loss * batch_n
                    running_corrects += batch_corrects
                    print('Running loss:{:.2f} Acc:{:.2f}   '.format(
                        batch_loss, batch_corrects / batch_n), end='\r')

                if is_train:
                    # Since PyTorch 1.1 the scheduler must be stepped *after* the
                    # optimiser; stepping it first decays the lr one epoch early.
                    scheduler.step()

                n_samples = max(dataset_size[phase], 1)  # guard against an empty split
                epoch_loss = running_loss / n_samples
                epoch_acc = running_corrects / n_samples
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                          epoch_acc * 100))

                if phase == 'valid' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
            print('--' * 20)
            print()
    except KeyboardInterrupt:
        print('Training interrupted; restoring the best weights seen so far.')

    time_elapsed = time.time() - since
    print('Training completed in {:.0f}m {:.0f}s'.format(
            time_elapsed//60, time_elapsed%60))
    # Reported as a percentage, like the per-epoch accuracy above.
    print('Best valid Acc: {:.4f}'.format(best_acc * 100))

    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Train the ResNet built above; the returned model (best-validation weights) is
# kept as `model_fit` for saving below.
model_fit = train_model(model=cloudModel,
                        dataloader=cloudClfDataset_loader,
                        criterion=criterion,
                        optimiser=optimiser,
                        scheduler=lr_scheduler,
                        device=device,
                        dataset_size=dataset_sizes,
                        num_epochs=20)

In [None]:
# Persist only the learned parameters (the state_dict), not the whole pickled module.
torch.save(obj=model_fit.state_dict(), f='3sky_states_customresnet.pth')