In [None]:
import os
import pickle
import sys
import time

import torch
from tqdm import tqdm
sys.path.insert(0, os.path.abspath('..'))

In [None]:
from includes.models_pretrained import get_unet_model

In [None]:
from includes.dataloader import prepare_loaders

In [None]:
# Build the train / validation loaders from the CSV index under ../data.
# NOTE(review): device is hardcoded to "cuda" here although the device cell below
# falls back to mps/cpu, and batchSize=4 differs from BATCH_SIZE in the config
# cell; shuffle=False also applies to the training loader. Confirm these are intended.
train_data, val_data = prepare_loaders(
    batchSize=4,
    device="cuda",
    numWorkers=2,
    shuffle=False,
    csvFile="../data/train.csv",
    basePath='../data/',
    shape=(200, 200),
)

In [None]:
from includes.utils.loss import IoULoss

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [None]:
# Build the "vanilla" U-Net variant on the selected device.
# NOTE(review): the positional args 1, 3, None are opaque from here — presumably
# channel / class counts and an optional weights/encoder argument; confirm against
# includes.models_pretrained.get_unet_model.
model = get_unet_model("vanilla", 1, 3, None, device)

In [None]:
# HYPERPARAMETERS
EPOCHS = 80
# NOTE(review): BATCH_SIZE and WORKERS are not referenced anywhere — the loaders
# are built with batchSize=4 / numWorkers=2 directly. Keep them in sync or wire them in.
BATCH_SIZE = 32
LR = 0.0005
WORKERS = 2
# File name for the checkpoint saved by train_mask() and the pickled history;
# both cells reference `model_name`, which was otherwise never defined.
model_name = "unet_vanilla"
# Per-epoch metrics appended by the training loop ('accuracy' holds the score_fn value).
history = {'loss' : [], 'accuracy' : [], 'val_loss' : [], 'val_accuracy' : [], 'lr' : []}

In [None]:
def validationStats(val_data, loss_fn, score_fn):
    """Evaluate the global ``model`` on ``val_data`` and return mean loss and score.

    Args:
        val_data: iterable of ``(data, target)`` batches; each tensor is moved to
            the global ``device`` before the forward pass.
        loss_fn: callable ``(output, target)`` returning a scalar tensor.
        score_fn: callable ``(output, target)`` returning a scalar tensor
            (reported as "Accuracy").

    Returns:
        Tuple ``(mean_loss, mean_score)`` averaged over the batches seen.

    Raises:
        ValueError: if ``val_data`` yields no batches (the means would be undefined).
    """
    model.eval()
    loss_running = 0.0
    accuracy_running = 0.0
    # Count batches ourselves: a loader without __len__ gives tqdm total=None,
    # and an empty one gives 0 — both broke the old `% total` / `/ total` logic.
    batches = 0
    try:
        with torch.no_grad():
            for data, target in tqdm(val_data):
                data = data.to(device=device)
                target = target.to(device=device)
                output = model(data)

                accuracy_running += score_fn(output, target).item()
                loss_running += loss_fn(output, target).item()
                batches += 1
    finally:
        # Always hand the model back in training mode (training resumes right after
        # this call in train_mask), even if evaluation raised part-way through.
        model.train()

    if batches == 0:
        raise ValueError("validation data yielded no batches")

    mean_loss = loss_running / batches
    mean_accuracy = accuracy_running / batches
    print('Accuracy: %.3f / loss: %.3f' % (mean_accuracy, mean_loss))
    return mean_loss, mean_accuracy

In [None]:
def train_fn(model, dataloader, optimizer, loss_fn, score_fn):
    """Run one training epoch and append the epoch means to the global ``history``.

    Args:
        model: network to train; it is switched to training mode first.
        dataloader: iterable of ``(data, target)`` batches; tensors are moved to
            the global ``device``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``(output, target)`` returning the scalar loss to minimise.
        score_fn: callable ``(output, target)`` returning a scalar tensor that is
            only logged (stored under ``history['accuracy']``).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Don't rely on a previous call having left the model in train mode.
    model.train()
    loop = tqdm(dataloader)
    loss_running = 0.0
    accuracy_running = 0.0
    # Count batches ourselves instead of trusting loop.total (None for loaders
    # without __len__, 0 for empty ones).
    batches = 0

    for data, target in loop:
        data = data.to(device=device)
        target = target.to(device=device)
        output = model(data)
        loss = loss_fn(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        # The score is only logged, so don't build an autograd graph for it.
        with torch.no_grad():
            score_value = score_fn(output, target).item()

        loop.set_postfix(loss=loss_value)
        loss_running += loss_value
        accuracy_running += score_value
        batches += 1

    if batches == 0:
        raise ValueError("training dataloader yielded no batches")

    mean_loss = loss_running / batches
    mean_accuracy = accuracy_running / batches
    print('Accuracy: %.3f / loss: %.3f' % (mean_accuracy, mean_loss))

    history['loss'].append(mean_loss)
    history['accuracy'].append(mean_accuracy)


def train_mask():
    """Train the global ``model`` for ``EPOCHS`` epochs, validating and checkpointing each epoch.

    Uses ``IoULoss`` as the objective, Adam with cosine-annealed learning rate,
    and appends validation loss/score plus the learning rate used for each epoch
    to the global ``history``. The model's ``state_dict`` is written after every
    epoch to ``drive/MyDrive/deepnetwork/trainedModels/<model_name>``.
    """
    loss_fn = IoULoss()
    # NOTE(review): `IoU` is not imported in this notebook (only IoULoss is imported
    # from includes.utils.loss) — add the correct import before running.
    score_fn = IoU()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-7)
    # Anneal over the whole run so the schedule follows EPOCHS if it is changed.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-5)

    checkpoint_dir = "drive/MyDrive/deepnetwork/trainedModels"
    os.makedirs(checkpoint_dir, exist_ok=True)

    for _ in range(EPOCHS):
        # Record the learning rate actually used this epoch (LR itself never changes;
        # the scheduler adjusts the optimizer's param groups).
        epoch_lr = optimizer.param_groups[0]['lr']

        train_fn(model, train_data, optimizer, loss_fn, score_fn)
        scheduler.step()

        loss_val, acc_val = validationStats(val_data, loss_fn, score_fn)

        history['val_accuracy'].append(acc_val)
        history['val_loss'].append(loss_val)
        history['lr'].append(epoch_lr)

        torch.save(model.state_dict(), f"{checkpoint_dir}/{model_name}")

In [None]:
# Train the network
# Runs the full training loop; appends per-epoch metrics to `history` and
# saves a model checkpoint after every epoch.

train_mask()

In [None]:
#
## save the history
#

history_dir = 'drive/MyDrive/deepnetwork/history'
os.makedirs(history_dir, exist_ok=True)
# `with` closes the file even if pickling fails; the timestamp keeps runs distinct.
with open(f'{history_dir}/{model_name}_{time.time()}', 'wb') as outfile:
    pickle.dump(history, outfile)

In [None]:
#
## validate the results
#

# Final evaluation of the trained model on the validation loader.
validationStats(val_data, loss_fn=IoULoss(), score_fn=IoU())

In [None]:
# load saved history

def _load_history(path):
    """Unpickle and return the training-history object stored at ``path``.

    The file is closed even if unpickling fails.
    NOTE: pickle can execute arbitrary code — only load history files you produced yourself.
    """
    with open(path, 'rb') as infile:
        return pickle.load(infile, encoding='bytes')


resnet34_history = _load_history('pretrained_models/history/resnet34History')
resnet50_history = _load_history('pretrained_models/history/resnet50History')
vgg16_history = _load_history('pretrained_models/history/vgg16History')
mobilenet_history = _load_history('pretrained_models/history/mobilenetHistory')
efficientnet_b1_history = _load_history('pretrained_models/history/efficientnetB1History')

In [None]:
#
## plot the history
#

import matplotlib.pyplot as plt


def _plot_history(ax, hist, title):
    """Plot the train/val score and loss curves of one history dict onto ``ax``.

    ``hist`` must provide the 'accuracy', 'val_accuracy', 'loss' and 'val_loss' series.
    """
    for key in ('accuracy', 'val_accuracy', 'loss', 'val_loss'):
        ax.plot(hist[key])
    ax.set_title(title)
    # Scores and losses share this axis, so it must not be labelled "accuracy" only.
    ax.set_ylabel('score / loss')
    ax.set_xlabel('epoch')
    ax.legend(['accuracy', 'val accuracy', 'loss', 'val loss'], loc='upper left')


def _plot_grid(panels, figsize, out_path):
    """Draw one subplot per ``(history, title)`` pair in a single row and save the figure."""
    # squeeze=False keeps `axes` 2-D even for a single panel, so indexing is uniform.
    fig, axes = plt.subplots(1, len(panels), figsize=figsize, squeeze=False)
    for ax, (hist, title) in zip(axes[0], panels):
        _plot_history(ax, hist, title)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)


_plot_grid(
    [
        (resnet34_history, 'Resnet34 Model'),
        (resnet50_history, 'Resnet50 Model'),
    ],
    figsize=(20, 10),
    out_path='drive/MyDrive/deepnetwork/plot1.png',
)

_plot_grid(
    [
        (efficientnet_b1_history, 'Efficientnet-B1 Model'),
        (mobilenet_history, 'Mobilenet_v2 Model'),
        (vgg16_history, 'VGG16 Model'),
    ],
    figsize=(20, 7),
    out_path='drive/MyDrive/deepnetwork/plot2.png',
)