In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [None]:
import torchvision
from torchvision import datasets, models, transforms

In [None]:
import json
import os
import random
from collections import namedtuple, OrderedDict
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import patches, patheffects
from PIL import ImageDraw, ImageFont

In [None]:
from session import *
from LR_Schedule.cos_anneal import CosAnneal
from LR_Schedule.lr_find import lr_find
from callbacks import *
from validation import *
import Datasets.ImageData as ImageData
from Transforms.ImageTransforms import *

In [None]:
# Automatically re-import edited project modules (session, callbacks, Transforms, ...)
# before each cell runs, so changes to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Environment check: CUDA availability, GPU name, and whether cuDNN is enabled.
# torch.cuda.get_device_name(0) raises when no CUDA device is present, so it is only
# queried when one is; otherwise the name is reported as None instead of crashing.
cuda_available = torch.cuda.is_available()
device_name = torch.cuda.get_device_name(0) if cuda_available else None
cuda_available, device_name, torch.backends.cudnn.enabled

In [None]:
# Dataset location. Set the ROAD_DAMAGE_DATA environment variable to run on a machine
# without this exact layout; the default is the original author's Windows path.
DATA_PATH = Path(os.environ.get('ROAD_DAMAGE_DATA', 'C:/fastai/courses/dl2/data/road_damage_dataset'))
MULTICLASS_CSV_PATH = DATA_PATH/'mc.csv'

# Square input resolution used by the transforms below.
im_size = 299

# Seed every RNG the pipeline may draw from so augmentation and dropout are repeatable
# across Restart & Run All. NOTE(review): assumes the project's random transforms and
# loaders draw from these generators — confirm. cuDNN kernels may still be
# non-deterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# ImageNet channel mean/std. The pretrained ResNet weights were trained on inputs
# normalised with these, and train/validation must use identical statistics, so they
# are defined once here instead of being repeated in each pipeline.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: random scale/crop to im_size, mild lighting jitter and horizontal
# flips for augmentation, then tensor conversion and normalisation.
# NOTE(review): RandomScale(im_size, 1.2) presumably scales up to 1.2x im_size before
# the crop — confirm in Transforms/ImageTransforms.py.
train_tfms = TransformList([
    RandomScale(im_size, 1.2),
    RandomCrop(im_size),
    RandomLighting(0.05, 0.05),
    RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation pipeline: deterministic resize only, same normalisation as training.
# NOTE(review): if Scale only resizes one side, non-square images would not batch
# together — confirm the dataset images are square or Scale yields im_size x im_size.
val_tfms = TransformList([
    Scale(im_size),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [None]:
# Width of the N-hot target vector. NOTE(review): hard-coded; presumably the number of
# damage classes in mc.csv — confirm it matches the CSV.
num_classes = 8

# Build the train/valid loaders from the CSV. The result is indexed below with
# data['train'] and data['valid'].
batch_size = 4
data = ImageData.from_csv(DATA_PATH, MULTICLASS_CSV_PATH, batch_size, train_tfms, val_tfms)

In [None]:
# Pretrained ResNet-18 backbone with its classifier replaced by a multi-label head.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features  # feature width entering the original classifier

# Adaptive pooling reduces any spatial size to 1x1, so im_size (299) inputs work.
# NOTE(review): presumably needed because older torchvision ResNets used a fixed-size
# average pool sized for 224px inputs.
model_ft.avgpool = nn.AdaptiveAvgPool2d(1)

# Keep the head as a positional Sequential (children 0, 1, 2) so parameter names stay
# the same as in previously saved checkpoints.
model_ft.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(num_ftrs, num_classes),
    nn.Sigmoid(),  # independent per-class probabilities, fed to nn.BCELoss
)

In [None]:
# The optimizer is passed as a class (not an instance); Session presumably builds it
# from the model parameters and learning rates.
optim_fn = optim.Adam

# The model head ends in nn.Sigmoid, so the loss receives probabilities. Dropping the
# Sigmoid and using nn.BCEWithLogitsLoss would be more numerically stable, but
# NOTE(review): NHotAccuracy may expect probabilities — confirm before switching.
criterion = nn.BCELoss()

In [None]:
# One learning rate per layer group: 1e-3 for the first nine groups and 1e-2 for the
# last. NOTE(review): presumably the last group is the new fc head and the count of ten
# matches how Session splits resnet18 — confirm in session.py.
sess = Session(model_ft, criterion, optim_fn, [1e-3] * 9 + [1e-2])

In [None]:
# Stage 1 setup: freeze the pretrained part so the new head trains first; the network is
# unfrozen later with sess.unfreeze().
# NOTE(review): presumably freeze() leaves only the last layer group trainable — confirm.
sess.freeze()

In [None]:
# Learning-rate sweep on the training loader, starting at 1e-5, to choose the rate below.
# NOTE(review): confirm lr_find restores the model weights afterwards; otherwise the
# training that follows starts from weights perturbed by the sweep.
lr_find(sess, data['train'], start_lr=1e-5)

In [None]:
# Rate chosen from the lr_find output above.
sess.set_lr(1e-3)

In [None]:
accuracy = NHotAccuracy(num_classes)
validator = Validator(data['valid'], accuracy)
lr_scheduler = CosAnneal(len(data['train']), T_mult=2)
schedule = TrainingSchedule(data['train'], [validator, lr_scheduler])

sess.train(schedule, 3)

In [None]:
for detail in accuracy.details:
    print(detail)

In [None]:
sess.save('TrainLayer1')

In [None]:
sess.load('TrainLayer1')

In [None]:
sess.unfreeze()

In [None]:
lr_find(sess, data['train'], start_lr=[*[1e-6] * 9, 1e-5])

In [None]:
sess.set_lr([*[1e-5] * 9, 1e-4])

In [None]:
accuracy = NHotAccuracy(num_classes)
validator = Validator(data['valid'], accuracy)
lr_scheduler = CosAnneal(len(data['train']), T_mult=2)
schedule = TrainingSchedule(data['train'], [validator, lr_scheduler])

sess.train(schedule, 7)

In [None]:
for detail in accuracy.details:
    print(detail)

In [None]:
# Presumably plots the learning rates recorded by the stage-2 CosAnneal scheduler.
lr_scheduler.plot()

In [None]:
sess.save('FullTrain')

In [None]:
sess.load('FullTrain')

In [None]:
# Stage 3: continue training with the stage-2 `schedule` object, so whatever state its
# CosAnneal scheduler and NHotAccuracy metric hold carries over rather than starting fresh.
# NOTE(review): confirm this carry-over is intended (with T_mult=2, 1+2+4 = 7 epochs so
# far would make 8 the next cycle length). This stage's metrics are never printed;
# inspect accuracy.details afterwards if needed.
sess.train(schedule, 8)

In [None]:
sess.save('FullTrain2')

In [None]:
sess.load('FullTrain2')