In [None]:
# selecting GPU
# These environment variables must be set before torch initialises CUDA, which
# is why this is the first cell. PCI_BUS_ID makes CUDA's device indices follow
# PCI bus order (the numbering nvidia-smi shows); CUDA_VISIBLE_DEVICES=0 exposes
# only that one GPU to this kernel, where it then appears as cuda:0.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0
# Interactive figures for the loss plot at the end of the notebook.
# NOTE(review): the `notebook` backend targets the classic Notebook UI; in
# JupyterLab use `%matplotlib widget` (ipympl) instead.
%matplotlib notebook

In [None]:
from utils import WingDataset

import os
import time
import json

import torch
from torch import nn
from torchvision import models, transforms

In [None]:
# basic setup stuff
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so that the randomly initialised fc layer and the
# DataLoader shuffle order are reproducible across Restart & Run All.
# (cuDNN kernels may still be nondeterministic on GPU.)
seed = 0
torch.manual_seed(seed)

# Output directory for checkpoints (model_latest.pth) and history.json.
# exist_ok=True replaces the racy exists-then-create check.
run_name = 'test_runs01/001_first_training'
os.makedirs(run_name, exist_ok=True)

model_name = 'resnet18'
batch_size = 64
epochs = 1

# training loss: mean squared error, i.e. the model is trained as a regressor
criterion = torch.nn.MSELoss().to(device)

# model: ImageNet-pretrained ResNet18 with the final fully connected layer
# replaced by a single-output linear layer (one regression value per image).
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
# of `weights=models.ResNet18_Weights.IMAGENET1K_V1`; kept for older versions.
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)
model.to(device)

# optimizer: Adam over all parameters (the backbone is fine-tuned, not frozen)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)



In [None]:
# DATASETS, DATALOADERS

# transforms: conversion to a pytorch tensor, then normalisation with the
# ImageNet channel mean/std used by the pretrained ResNet weights.
# Dataset augmentation transforms can be added here later, but be careful that
# each one preserves the label.
img_tr = transforms.Compose([transforms.ToTensor(),
                             transforms.Normalize([0.485, 0.456, 0.406],
                                                  [0.229, 0.224, 0.225])])

# training and validation sets, by default 3 runs are used for test (see WingDataset), rest is training
train_ds = WingDataset(test=False, transforms=img_tr)
val_ds = WingDataset(test=True, transforms=img_tr)

# A tuple rather than a set, so the iteration order is fixed and reproducible.
sets = ('train', 'val')
dsets = {'train': train_ds, 'val': val_ds}

# Training: shuffled, and the trailing partial batch is dropped (kept from the
# original setup; presumably to avoid tiny, noisy BatchNorm batches).
# Validation: fixed order with every sample kept, so the validation loss covers
# the whole held-out set instead of a random subset of it.
dataloaders = {
    x: torch.utils.data.DataLoader(dsets[x], batch_size=batch_size,
                                   shuffle=(x == 'train'), num_workers=8,
                                   drop_last=(x == 'train'))
    for x in sets
}

# Number of samples each loader actually yields per epoch. This is the divisor
# for the summed batch losses. With drop_last the trailing partial batch is
# skipped, so for 'train' this can be smaller than len(dataset).
dataset_sizes = {
    x: (len(dataloaders[x]) * batch_size if dataloaders[x].drop_last
        else len(dsets[x]))
    for x in sets
}

In [None]:
def _run_phase(model, criterion, loader, device, optimizer, train):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    Args:
        model: the network; switched to train mode if ``train`` else eval mode.
        criterion: loss function; assumed to use reduction='mean', so each
            batch loss is re-weighted by its batch size before averaging.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device the batches are moved to.
        optimizer: stepped after every batch when ``train`` is True.
        train: whether to back-propagate and update the weights.

    Returns:
        The loss averaged over the samples actually seen, or ``float('nan')``
        if ``loader`` yielded no batches (instead of a misleading 0.0).
    """
    model.train(train)
    running_loss = 0.0
    n_seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device).float()
        if train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            # Give the prediction the target's shape. A (B, 1) output against a
            # (B,) target would otherwise broadcast to (B, B) inside MSELoss
            # and silently yield a wrong loss.
            loss = criterion(outputs.reshape_as(labels), labels)
            if train:
                loss.backward()
                optimizer.step()
        batch_n = inputs.size(0)
        running_loss += loss.item() * batch_n
        n_seen += batch_n
    return running_loss / n_seen if n_seen else float('nan')


def train_model(model, criterion, num_epochs, dataloaders, device, optimizer, start_epoch=0,
                run_dir=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    After every epoch the whole model is saved to ``<run_dir>/model_latest.pth``
    and the loss history to ``<run_dir>/history.json``.

    Args:
        model: network to train (updated in place and also returned).
        criterion: loss function with reduction='mean'.
        num_epochs: number of epochs to run.
        dataloaders: dict with 'train' and 'val' iterables of (inputs, labels).
        device: device the batches are moved to.
        optimizer: optimizer over ``model``'s parameters.
        start_epoch: when non-zero, training is resumed, so the existing
            ``history.json`` in ``run_dir`` is loaded and appended to.
        run_dir: output directory; defaults to the notebook-level ``run_name``.

    Returns:
        ``(history, model)``. ``history`` maps 'train_loss' and 'val_loss' to
        per-epoch mean losses. 'mean_err' is never filled here and is kept only
        so readers of existing history files still find the key.
    """
    if run_dir is None:
        run_dir = run_name  # global defined in the setup cell
    os.makedirs(run_dir, exist_ok=True)
    history_path = os.path.join(run_dir, 'history.json')

    since = time.time()
    if start_epoch != 0:
        with open(history_path, 'r') as fin:
            history = json.load(fin)
    else:
        history = {'train_loss': [],
                   'val_loss': [],
                   'mean_err': []}

    for eph in range(start_epoch, start_epoch + num_epochs):
        for phase in ('train', 'val'):
            epoch_loss = _run_phase(model, criterion, dataloaders[phase], device,
                                    optimizer, train=(phase == 'train'))
            history[phase + '_loss'].append(epoch_loss)
            if phase == 'train':
                print(epoch_loss)
            else:
                print('\t\t{}'.format(epoch_loss))

        # Checkpoint once per epoch, after both phases. Validation does not
        # change the weights, so saving after each phase was redundant.
        torch.save(model, os.path.join(run_dir, 'model_latest.pth'))
        with open(history_path, 'w') as fout:
            json.dump(history, fout)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    return history, model

In [None]:
# Train from scratch (start_epoch=0): returns the per-epoch loss history and the trained model.
h, model = train_model(model=model, criterion=criterion, num_epochs=epochs,
                       dataloaders=dataloaders, device=device, optimizer=optimizer,
                       start_epoch=0)

In [None]:
from matplotlib import pyplot as plt

# Per-epoch loss curves. Markers are required: with a single epoch each series
# is one point, which a marker-less line style ('k' / 'r') does not draw at all.
fig, ax = plt.subplots()
ax.plot(range(1, len(h['train_loss']) + 1), h['train_loss'], 'k-o', label='train')
ax.plot(range(1, len(h['val_loss']) + 1), h['val_loss'], 'r-o', label='val')
ax.set(xlabel='epoch', ylabel='loss (MSE)', title='Training / validation loss')
ax.legend();

In [None]:
# for inputs, labels in dataloaders['train']:
#     break