In [1]:
from dataset import *
from specific_train import *

import matplotlib.pyplot as plt

import torch, torchvision
import torchvision.transforms as transforms

In [2]:
import warnings

warnings.filterwarnings('ignore')

In [3]:
# Mini-batch size shared by the train and validation loaders.
batch_size = 32

# Fix the global torch RNG so the random crops/rotations, the DataLoader
# shuffle order and any weight initialisation in later cells are repeatable
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Standard ImageNet per-channel normalisation statistics (presumably chosen to
# match the pretrained ResNet backbones used below — confirm).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Crop size as (height, width), following torchvision's crop convention.
# NOTE(review): Resize(256) only fixes the *shorter* side at 256 px, so an image
# needs a landscape aspect ratio >= 400/256 (1.5625) for a 400 px-wide crop.
# RandomCrop raises on narrower images, while CenterCrop silently zero-pads
# them. Confirm this holds for the whole dataset.
CROP_SIZE = (224, 400)

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(CROP_SIZE),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(CROP_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


train_dataset = TrainDataset(transform=train_transform)
val_dataset = ValDataset(transform=val_transform)

# Shuffle only the training data; validation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
def load_baseline_checkpoint(tag, num_classes, checkpoint_path, arch='resnet18'):
    """Build a `Baseline` wrapper, restore its weights and switch it to eval mode.

    Args:
        tag: Tag forwarded to `Baseline(tag=...)`.
        num_classes: Class count forwarded to `Baseline(num_classes=...)`.
        checkpoint_path: Path to a state_dict file for `baseline.model`.
        arch: Backbone name forwarded to `Baseline(model=...)`.

    Returns:
        The `Baseline` wrapper. Its `.model` has the checkpoint weights loaded
        and is in eval mode.
    """
    baseline = Baseline(model=arch, num_classes=num_classes, tag=tag)
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines.
    # load_state_dict then copies the tensors onto whatever device the
    # model's parameters already live on.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    baseline.model.load_state_dict(state_dict)
    baseline.model.eval()
    return baseline


Mmodel = load_baseline_checkpoint('M', num_classes=48, checkpoint_path='./M_7_40.pt')
Hmodel = load_baseline_checkpoint('H', num_classes=2, checkpoint_path='./H_0_15.pt')
LRBmodel = load_baseline_checkpoint('LRB', num_classes=3, checkpoint_path='./LRB_4_120.pt')

# Combine the three restored sub-models. The ensemble's output size follows the
# dataset's label set, and it logs every 5 iterations (see the training output).
model = Ensemble(Mmodel.model, Hmodel.model, LRBmodel.model, num_classes=len(train_dataset.label_info),
                 print_freq=5, save=False, tag='ensemble')

In [8]:
# Optimisation settings consumed by the training cell that follows:
# number of passes over train_loader, learning rate and weight decay.
epochs, lr, weight_decay = 5, 0.01, 1e-5

In [9]:
# Train the ensemble with the hyperparameters from the previous cell.
# The plotting cells below read model.train_losses / test_losses / train_acc /
# test_acc, presumably populated by this call.
model.train(
    train_loader,
    val_loader,
    **{'epochs': epochs, 'lr': lr, 'weight_decay': weight_decay},
)

Epoch 1 Started...
Iteration : 1 - Train Loss : 7.607970, Test Loss : 8.121062, Train Acc : 0.000000, Test Acc : 1.019022
Iteration : 6 - Train Loss : 7.613744, Test Loss : 7.045658, Train Acc : 9.375000, Test Acc : 10.326087
Iteration : 11 - Train Loss : 5.501492, Test Loss : 5.480012, Train Acc : 18.750000, Test Acc : 15.421196
Iteration : 16 - Train Loss : 5.121565, Test Loss : 4.288275, Train Acc : 25.000000, Test Acc : 29.551630
Iteration : 21 - Train Loss : 3.701386, Test Loss : 3.297042, Train Acc : 25.000000, Test Acc : 37.567935
Iteration : 26 - Train Loss : 2.955717, Test Loss : 2.521130, Train Acc : 46.875000, Test Acc : 43.274457
Iteration : 31 - Train Loss : 1.682958, Test Loss : 1.934153, Train Acc : 53.125000, Test Acc : 56.929348
Iteration : 36 - Train Loss : 2.501483, Test Loss : 1.521681, Train Acc : 59.375000, Test Acc : 65.285326
Iteration : 41 - Train Loss : 1.436785, Test Loss : 1.101688, Train Acc : 62.500000, Test Acc : 73.165761
Iteration : 46 - Train Loss : 0.

KeyboardInterrupt: 

In [None]:
# Font size for axis labels and titles; the accuracy-plot cell below reuses it.
label_fontsize = 25

# Train vs. test loss histories stored on `model`.
# NOTE(review): the x-axis is the index into these lists. If values are only
# recorded every print_freq (=5) iterations, one step here spans 5 training
# iterations — confirm against the Ensemble implementation.
fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(model.train_losses, label='Train')
ax.plot(model.test_losses, color='red', label='Test')
ax.set_title('Ensemble loss', fontsize=label_fontsize)
ax.set_xlabel('Step', fontsize=label_fontsize)
ax.set_ylabel('Loss', fontsize=label_fontsize)
ax.legend(fontsize=20)
plt.show()

In [None]:
# Train vs. test accuracy histories stored on `model`.
# Relies on `label_fontsize`, defined in the loss-plot cell above.
# NOTE(review): the training log prints accuracies as percentages. Confirm
# whether the stored lists use the same scale before adding units to the y-axis.
fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(model.train_acc, label='Train')
ax.plot(model.test_acc, color='red', label='Test')
ax.set_title('Ensemble accuracy', fontsize=label_fontsize)
ax.set_xlabel('Step', fontsize=label_fontsize)
ax.set_ylabel('Acc', fontsize=label_fontsize)
ax.legend(fontsize=20)
plt.show()