In [1]:
import cv2
import PIL
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import tqdm
import os

from PIL import Image
from torchvision import transforms as T
from torchvision.models.alexnet import AlexNet
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from torch.utils.data import IterableDataset, DataLoader, Dataset
from sklearn.metrics import classification_report, accuracy_score, f1_score, average_precision_score
from sklearn.model_selection import train_test_split, KFold

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class BonzDataset(Dataset):
    def __init__(self, data):
        super(BonzDataset, self).__init__()
        self.image_prefixes = data.image_prefixes.values
        self.features = data.features.values
        self.img_tensors = data.img_tensors.values
        if 'labels' in data:
            self.labels = data.labels.values
            self.one_hot_labels = data.one_hot_labels.values
        else:
            self.labels=None
    
    def __len__(self):
        return len(self.image_prefixes)
    
    def __getitem__(self, idx):
        outputs = (self.img_tensors[idx], self.features[idx],)
        if self.labels is not None:
            outputs += (self.one_hot_labels[idx], self.labels[idx])
        return outputs
    
        
class Bonz(nn.Module):
    def __init__(self, hidden_dim=100, feature_selection=12):
        super(Bonz, self).__init__()
        self.resnet = resnet50(True)
        self.BiLSTM = nn.LSTM(2048, hidden_dim, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(hidden_dim*2 + feature_selection, 4098),
            nn.LeakyReLU(0.001, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(4098, 2048),
            nn.LeakyReLU(0.001, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(2048, 20),
        )
        self.classifier.apply(self._init_weights)
    
    def forward(self, imgs, features=None):
        # Generate featuress from each images
        x = []
        for img in imgs:
            temp = self.do_resnet(img)
            x.append(temp.unsqueeze(1))
        x = torch.cat(x, 1)
        
        # LSTM step
        x, _ = self.BiLSTM(x)
        last_hidden = x[:,-1,:]
        
        # Concate features
        total_features = torch.cat([last_hidden, features], -1)
        predict = self.classifier(total_features)

        return predict
    
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    
    def do_resnet(self, img):
        x = self.resnet.conv1(img)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)

        x = self.resnet.avgpool(x)
        x = torch.flatten(x, 1)
        
        return x
    
    def freeze_resnet(self):
        for w in self.resnet.parameters():
            w.requires_grad = False
            
    def unfreeze_resnet(self):
        for w in self.resnet.parameters():
            w.requires_grad = True 

In [2]:
#del model
def get_data(path):
    data = pd.read_csv(path, index_col=0)
    resnet_idx = list(data.columns).index('data_AUTOGRAPHER_RESNET_mean_tench, Tinca tinca')
    new_data = data.iloc[:,:resnet_idx]
    new_data['labels'] = [int(i[-2:])-1 for i in new_data.event_id.values]
    new_data['image_prefixes'] = list(map(lambda x,y,z: str(x)+'_'+str(y)+'_'+str(z), 
                                      new_data.sub_id.values, 
                                      new_data.source.values, 
                                      new_data.event_id.values))
    new_data['one_hot_labels'] = list(map(lambda x: nn.functional.one_hot(torch.tensor(x), 20).float(), new_data.labels.values))
    new_data['features'] = [torch.tensor(i).float() for i in new_data.iloc[:, 3:-3].values]
    new_data['img_tensors'] = get_img_tensors(new_data['image_prefixes'].values)
    return new_data.iloc[:, -5:]


def get_test_data(path):
    data = pd.read_csv(path)
    new_data = data.iloc[:,0:16]
    new_data['labels'] = [i-1 for i in new_data.action_id.values]
    new_data['image_prefixes'] = list(map(lambda x,y: str(x)+'_'+str(y), 
                                      new_data.sub_id.values, 
                                      new_data.event_id.values))
    new_data['one_hot_labels'] = list(map(lambda x: nn.functional.one_hot(torch.tensor(x), 20).float(), new_data.labels.values))
    new_data['features'] = [torch.tensor(i).float() for i in new_data.iloc[:, 4:16].values]
    new_data['img_tensors'] = get_img_tensors(new_data['image_prefixes'].values)
    return new_data.iloc[:, -5:]


def get_img_tensors(image_prefixes):
    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    temp = []
    
    for img_prefix in tqdm.notebook.tqdm(image_prefixes):
        img_paths = []
        i = 0
        img_name = img_prefix+'_'+str(i)+'.jpg'
        while img_name in os.listdir('./OUTPUT_MERGED/AUTOGRAPHER/'):
            img_paths.append('./OUTPUT_MERGED/AUTOGRAPHER/'+img_name)
            i += 1
            img_name = img_prefix+'_'+str(i)+'.jpg'

        # Transform images to tensors
        img_tensors = []
        for path in img_paths:
            img = Image.open(path)
            img_tensors.append(transform(img))

        temp.append(img_tensors)

    return temp

        
train_data = get_data('min_max.csv')
#test_data = get_test_data('test_normalization(delete).csv')


model = Bonz(hidden_dim=2048, feature_selection=train_data.features[0].shape[0])
origin_state_dict = model.state_dict()

HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




# 10-Folds CV

In [3]:
# Train with UNFREEZE ResNet
model.unfreeze_resnet()

model.to(DEVICE)
last_predicts = []
last_predict_matrix = []

kF = KFold(n_splits=10)
for train_idx, val_idx in tqdm.notebook.tqdm(kF.split(train_data.labels.values), desc='K-Fold'):
    train_dataset = BonzDataset(train_data.iloc[train_idx])
    val_dataset = BonzDataset(train_data.iloc[val_idx])
    
    # Training Phase
    model.train()
    model.load_state_dict(origin_state_dict) #Load origin state dict
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    
    for e in tqdm.notebook.trange(20):
        s = time.time()
        total_loss = 0
        predicts = []
        y_true = []
        data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
        for img_tensors, features, one_hot_label, label in data_loader:
            img_tensors = [ts.to(DEVICE) for ts in img_tensors]
            features = features.to(DEVICE)
            one_hot_label = one_hot_label.to(DEVICE)
            label = label.to(DEVICE)

            optimizer.zero_grad()
            predict = model(img_tensors, features)
            loss = loss_fn(predict, label)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            predict = predict.detach().cpu()
            predict = torch.argmax(predict, 1)
            predicts.extend(predict.tolist())
            y_true.extend(label.detach().cpu().tolist())
        
    acc_score = accuracy_score(y_true, predicts)
    f1_scr = f1_score(y_true, predicts, average='micro')

    print(f'Loss = {total_loss:.2f},\t Train_acc={acc_score:.2f},\t Train_f1={f1_scr:.2f},\t Time = {time.time()-s:.2f}')
    
    # Validation Phase
    model.eval()
    val_data_loader = DataLoader(val_dataset)
    for img_tensors, features, one_hot_label, label in val_data_loader:
        img_tensors = [ts.to(DEVICE) for ts in img_tensors]
        features = features.to(DEVICE)
        one_hot_label = one_hot_label.to(DEVICE)
        label = label.to(DEVICE)

        with torch.no_grad():
            predict = model(img_tensors, features)
        predict = predict.detach().cpu()
        last_predict_matrix.extend(predict)
        predict = torch.argmax(predict, 1)
        last_predicts.extend(predict.tolist())
        
print(classification_report(train_data.labels.values, last_predicts))

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='K-Fold', max=1.0, style=ProgressStyle(d…

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Loss = 6.27,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 20.46


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Loss = 11.70,	 Train_acc=0.92,	 Train_f1=0.92,	 Time = 20.42


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Loss = 4.73,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 19.45


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Loss = 7.33,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 18.76


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Loss = 4.14,	 Train_acc=0.99,	 Train_f1=0.99,	 Time = 19.41


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Loss = 6.84,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 19.00


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Loss = 6.88,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 18.83


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Loss = 8.40,	 Train_acc=0.94,	 Train_f1=0.94,	 Time = 19.31


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Loss = 9.40,	 Train_acc=0.94,	 Train_f1=0.94,	 Time = 18.70


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Loss = 8.19,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 18.94

              precision    recall  f1-score   support

           0       0.60      0.60      0.60        10
           1       0.67      0.60      0.63        10
           2       0.70      0.70      0.70        10
           3       0.21      0.30      0.25        10
           4       0.71      0.50      0.59        10
           5       0.60      0.60      0.60        10
           6       0.57      0.40      0.47        10
           7       0.82      0.90      0.86        10
           8       0.40      0.40      0.40        10
           9       0.70      0.70      0.70        10
          10       0.58      0.70      0.64        10
          11       0.62      0.80      0.70        10
          12       0.78      0.70      0.74        10
          13       0.36      0.40      0.38        10
          14       0.17      0.10      0.12        10
          15       0.11      0.10      0.11        10
          16       

# Leave One Out CV

In [6]:
# Train with UNFREEZE ResNet
model.unfreeze_resnet()

model.to(DEVICE)
#last_predicts = []
#last_predict_matrix = []

for epoch in range(111,200):
    # Split data leave-one-out
    idx = np.arange(200)
    idx = np.delete(idx, epoch)
    train_dataset = BonzDataset(train_data.iloc[idx])
    val_dataset = BonzDataset(train_data.iloc[[epoch]])
    
    # Training Phase
    model.train()
    model.load_state_dict(origin_state_dict) #Load origin state dict
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    
    for e in tqdm.notebook.trange(20):
        s = time.time()
        total_loss = 0
        predicts = []
        y_true = []
        data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
        for img_tensors, features, one_hot_label, label in data_loader:
            img_tensors = [ts.to(DEVICE) for ts in img_tensors]
            features = features.to(DEVICE)
            one_hot_label = one_hot_label.to(DEVICE)
            label = label.to(DEVICE)

            optimizer.zero_grad()
            predict = model(img_tensors, features)
            loss = loss_fn(predict, label)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            predict = predict.detach().cpu()
            predict = torch.argmax(predict, 1)
            predicts.extend(predict.tolist())
            y_true.extend(label.detach().cpu().tolist())
        
    acc_score = accuracy_score(y_true, predicts)
    f1_scr = f1_score(y_true, predicts, average='micro')

    print(f'LOOCV = {epoch},\t Loss = {total_loss:.2f},\t Train_acc={acc_score:.2f},\t Train_f1={f1_scr:.2f},\t Time = {time.time()-s:.2f}')
    
    # Validation Phase
    model.eval()
    val_data_loader = DataLoader(val_dataset)
    for img_tensors, features, one_hot_label, label in val_data_loader:
        img_tensors = [ts.to(DEVICE) for ts in img_tensors]
        features = features.to(DEVICE)
        one_hot_label = one_hot_label.to(DEVICE)
        label = label.to(DEVICE)

        with torch.no_grad():
            predict = model(img_tensors, features)
        predict = predict.detach().cpu()
        last_predict_matrix.extend(predict)
        predict = torch.argmax(predict, 1)
        last_predicts.extend(predict.tolist())
        
print(classification_report(train_data.labels.values, last_predicts))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 111,	 Loss = 12.01,	 Train_acc=0.92,	 Train_f1=0.92,	 Time = 20.95


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 112,	 Loss = 7.47,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 20.94


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 113,	 Loss = 6.88,	 Train_acc=0.95,	 Train_f1=0.95,	 Time = 21.13


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 114,	 Loss = 9.66,	 Train_acc=0.95,	 Train_f1=0.95,	 Time = 21.08


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 115,	 Loss = 6.68,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 21.34


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 116,	 Loss = 8.46,	 Train_acc=0.93,	 Train_f1=0.93,	 Time = 21.08


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 117,	 Loss = 5.75,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.12


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 118,	 Loss = 7.09,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 20.88


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 119,	 Loss = 5.51,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 22.91


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 120,	 Loss = 5.18,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 21.59


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 121,	 Loss = 6.34,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 21.23


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 122,	 Loss = 10.45,	 Train_acc=0.93,	 Train_f1=0.93,	 Time = 21.06


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 123,	 Loss = 7.28,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 22.92


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 124,	 Loss = 6.24,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 20.80


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 125,	 Loss = 7.93,	 Train_acc=0.95,	 Train_f1=0.95,	 Time = 20.99


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 126,	 Loss = 10.94,	 Train_acc=0.94,	 Train_f1=0.94,	 Time = 21.14


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 127,	 Loss = 9.23,	 Train_acc=0.94,	 Train_f1=0.94,	 Time = 21.37


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 128,	 Loss = 2.90,	 Train_acc=0.99,	 Train_f1=0.99,	 Time = 20.83


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 129,	 Loss = 10.96,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.00


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 130,	 Loss = 8.87,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.10


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 131,	 Loss = 7.01,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 21.14


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 132,	 Loss = 3.70,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 21.25


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 133,	 Loss = 6.16,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 21.08


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 134,	 Loss = 5.45,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 20.89


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 135,	 Loss = 8.34,	 Train_acc=0.94,	 Train_f1=0.94,	 Time = 20.75


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 136,	 Loss = 6.41,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.47


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 137,	 Loss = 10.10,	 Train_acc=0.93,	 Train_f1=0.93,	 Time = 21.39


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 138,	 Loss = 6.07,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 20.45


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 139,	 Loss = 6.08,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 20.94


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 140,	 Loss = 13.62,	 Train_acc=0.93,	 Train_f1=0.93,	 Time = 20.92


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 141,	 Loss = 7.32,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 20.90


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 142,	 Loss = 6.87,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 20.72


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 143,	 Loss = 6.11,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 20.96


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 144,	 Loss = 5.93,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 21.11


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 145,	 Loss = 6.51,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 20.97


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 146,	 Loss = 3.56,	 Train_acc=0.99,	 Train_f1=0.99,	 Time = 20.74


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 147,	 Loss = 6.79,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 20.44


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 148,	 Loss = 8.10,	 Train_acc=0.95,	 Train_f1=0.95,	 Time = 20.76


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 149,	 Loss = 7.17,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.18


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 150,	 Loss = 6.93,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.16


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 151,	 Loss = 6.03,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 20.88


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 152,	 Loss = 9.53,	 Train_acc=0.95,	 Train_f1=0.95,	 Time = 21.14


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 153,	 Loss = 5.16,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 20.80


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 154,	 Loss = 12.80,	 Train_acc=0.94,	 Train_f1=0.94,	 Time = 20.86


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 155,	 Loss = 12.64,	 Train_acc=0.93,	 Train_f1=0.93,	 Time = 21.04


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 156,	 Loss = 11.71,	 Train_acc=0.93,	 Train_f1=0.93,	 Time = 21.90


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 157,	 Loss = 11.29,	 Train_acc=0.92,	 Train_f1=0.92,	 Time = 20.92


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 158,	 Loss = 4.85,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 20.73


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 159,	 Loss = 4.87,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 20.93


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 160,	 Loss = 8.14,	 Train_acc=0.94,	 Train_f1=0.94,	 Time = 21.11


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 161,	 Loss = 4.78,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 21.38


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 162,	 Loss = 6.83,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 20.97


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 163,	 Loss = 11.96,	 Train_acc=0.94,	 Train_f1=0.94,	 Time = 21.38


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 164,	 Loss = 6.34,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 21.76


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 165,	 Loss = 6.01,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 21.23


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 166,	 Loss = 8.76,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.18


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 167,	 Loss = 4.56,	 Train_acc=0.99,	 Train_f1=0.99,	 Time = 22.53


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 168,	 Loss = 3.59,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 21.30


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 169,	 Loss = 7.24,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.26


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 170,	 Loss = 6.47,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 22.02


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 171,	 Loss = 9.06,	 Train_acc=0.95,	 Train_f1=0.95,	 Time = 22.52


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 172,	 Loss = 9.80,	 Train_acc=0.93,	 Train_f1=0.93,	 Time = 22.71


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 173,	 Loss = 6.93,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 21.76


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 174,	 Loss = 9.26,	 Train_acc=0.95,	 Train_f1=0.95,	 Time = 21.63


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 175,	 Loss = 5.25,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 21.89


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 176,	 Loss = 6.34,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 22.03


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 177,	 Loss = 5.49,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 22.12


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 178,	 Loss = 6.62,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.64


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 179,	 Loss = 5.53,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.83


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 180,	 Loss = 5.33,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.24


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 181,	 Loss = 6.41,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 22.36


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 182,	 Loss = 8.52,	 Train_acc=0.94,	 Train_f1=0.94,	 Time = 21.84


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 183,	 Loss = 5.21,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 21.49


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 184,	 Loss = 9.06,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 20.75


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 185,	 Loss = 5.74,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 22.51


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 186,	 Loss = 5.02,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 21.90


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 187,	 Loss = 15.55,	 Train_acc=0.91,	 Train_f1=0.91,	 Time = 22.95


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 188,	 Loss = 13.20,	 Train_acc=0.92,	 Train_f1=0.92,	 Time = 21.92


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 189,	 Loss = 16.26,	 Train_acc=0.90,	 Train_f1=0.90,	 Time = 21.42


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 190,	 Loss = 9.29,	 Train_acc=0.93,	 Train_f1=0.93,	 Time = 22.51


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 191,	 Loss = 5.90,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 21.67


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 192,	 Loss = 7.03,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.50


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 193,	 Loss = 5.38,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 24.45


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 194,	 Loss = 5.69,	 Train_acc=0.97,	 Train_f1=0.97,	 Time = 27.83


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 195,	 Loss = 10.10,	 Train_acc=0.95,	 Train_f1=0.95,	 Time = 26.98


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 196,	 Loss = 6.26,	 Train_acc=0.98,	 Train_f1=0.98,	 Time = 27.32


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 197,	 Loss = 9.96,	 Train_acc=0.94,	 Train_f1=0.94,	 Time = 27.55


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 198,	 Loss = 8.92,	 Train_acc=0.96,	 Train_f1=0.96,	 Time = 21.92


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


LOOCV = 199,	 Loss = 9.37,	 Train_acc=0.95,	 Train_f1=0.95,	 Time = 22.49
              precision    recall  f1-score   support

           0       0.82      0.90      0.86        10
           1       0.38      0.50      0.43        10
           2       0.80      0.80      0.80        10
           3       0.33      0.20      0.25        10
           4       0.56      0.50      0.53        10
           5       0.50      0.60      0.55        10
           6       0.57      0.40      0.47        10
           7       0.89      0.80      0.84        10
           8       0.62      0.50      0.56        10
           9       0.69      0.90      0.78        10
          10       0.50      0.60      0.55        10
          11       0.44      0.40      0.42        10
          12       0.89      0.80      0.84        10
          13       0.11      0.10      0.11        10
          14       0.25      0.30      0.27        10
          15       0.08      0.10      0.09        10
      

In [17]:
'''
# Train with freeze ResNet
model.freeze_resnet()



loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

model.train()
model.to(DEVICE)
for epoch in range(20):
    s = time.time()
    print('Epoch: ', epoch, '\n--------------------------')
    predicts = []
    total_loss = 0
    data_s = time.time()
    y_true = []
    
    data_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    for img_tensors, features, one_hot_label, label in tqdm.notebook.tqdm(data_loader):
        img_tensors = [ts.to(DEVICE) for ts in img_tensors]
        features = features.to(DEVICE)
        one_hot_label = one_hot_label.to(DEVICE)
        label = label.to(DEVICE)
        
        optimizer.zero_grad()
        predict = model(img_tensors, features)
        loss = loss_fn(predict, label)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        predict = predict.detach().cpu()
        predict = torch.argmax(predict, 1)
        predicts.extend(predict.tolist())
        y_true.extend(label.detach().cpu().tolist())

    print('Loss = ', total_loss)
    print('Time = ', time.time() - s)
    print('Accuracy Score: ', accuracy_score(y_true, predicts))
    print('F1 Score: ', f1_score(y_true, predicts, average='micro'))
    #print(classification_report(train_data.labels.values, predicts))
    print('___________________________________________')
'''

Epoch:  0 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  50.18169403076172
Time =  4.078999757766724
Accuracy Score:  0.04
F1 Score:  0.04
___________________________________________
Epoch:  1 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  39.51239991188049
Time =  4.244995594024658
Accuracy Score:  0.065
F1 Score:  0.065
___________________________________________
Epoch:  2 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  37.20724558830261
Time =  3.983999252319336
Accuracy Score:  0.09
F1 Score:  0.09
___________________________________________
Epoch:  3 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  34.04769706726074
Time =  4.074998140335083
Accuracy Score:  0.125
F1 Score:  0.125
___________________________________________
Epoch:  4 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  31.932181358337402
Time =  4.1951024532318115
Accuracy Score:  0.185
F1 Score:  0.185
___________________________________________
Epoch:  5 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  30.395508289337158
Time =  3.9939982891082764
Accuracy Score:  0.185
F1 Score:  0.185
___________________________________________
Epoch:  6 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  26.694459676742554
Time =  4.0289998054504395
Accuracy Score:  0.29
F1 Score:  0.29
___________________________________________
Epoch:  7 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  24.283324718475342
Time =  4.420866250991821
Accuracy Score:  0.31
F1 Score:  0.31
___________________________________________
Epoch:  8 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  22.59529173374176
Time =  4.105997562408447
Accuracy Score:  0.36
F1 Score:  0.36
___________________________________________
Epoch:  9 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  21.076433777809143
Time =  4.010000228881836
Accuracy Score:  0.42
F1 Score:  0.41999999999999993
___________________________________________
Epoch:  10 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  20.681477069854736
Time =  4.41600227355957
Accuracy Score:  0.43
F1 Score:  0.42999999999999994
___________________________________________
Epoch:  11 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  20.630675673484802
Time =  4.285999059677124
Accuracy Score:  0.42
F1 Score:  0.41999999999999993
___________________________________________
Epoch:  12 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  18.130861282348633
Time =  4.12399959564209
Accuracy Score:  0.495
F1 Score:  0.495
___________________________________________
Epoch:  13 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  14.896356344223022
Time =  4.128995895385742
Accuracy Score:  0.525
F1 Score:  0.525
___________________________________________
Epoch:  14 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  15.043651521205902
Time =  4.213002681732178
Accuracy Score:  0.545
F1 Score:  0.545
___________________________________________
Epoch:  15 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  16.906999707221985
Time =  4.260891675949097
Accuracy Score:  0.51
F1 Score:  0.51
___________________________________________
Epoch:  16 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  14.9858118891716
Time =  4.142995119094849
Accuracy Score:  0.555
F1 Score:  0.555
___________________________________________
Epoch:  17 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  13.611359715461731
Time =  4.264477729797363
Accuracy Score:  0.615
F1 Score:  0.615
___________________________________________
Epoch:  18 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  14.089218556880951
Time =  4.068522214889526
Accuracy Score:  0.58
F1 Score:  0.58
___________________________________________
Epoch:  19 
--------------------------


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


Loss =  14.25866836309433
Time =  4.2250001430511475
Accuracy Score:  0.6
F1 Score:  0.6
___________________________________________


In [14]:
'''
data_loader = DataLoader(test_dataset, batch_size=100)

model.eval()
model.to(DEVICE)

s = time.time()
predicts = []
for img_tensors, features, one_hot_label, label in tqdm.notebook.tqdm(data_loader):
    img_tensors = [ts.to(DEVICE) for ts in img_tensors]
    features = features.to(DEVICE)
    one_hot_label = one_hot_label.to(DEVICE)
    label = label.to(DEVICE)

    with torch.no_grad():
        predict = model(img_tensors, features)
    predict = predict.detach().cpu()
    predict = torch.argmax(predict, 1)
    predicts.extend(predict.tolist())

print('Time = ', time.time() - s)
print('Accuracy Score: ', accuracy_score(test_data.labels.values, predicts))
print('F1 Score: ', f1_score(test_data.labels.values, predicts, average='micro'))
print(classification_report(test_data.labels.values, predicts))
print('___________________________________________')
'''

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Time =  1.264162540435791
Accuracy Score:  0.61
F1 Score:  0.61
              precision    recall  f1-score   support

           0       0.80      0.80      0.80         5
           1       1.00      0.20      0.33         5
           2       0.40      0.80      0.53         5
           3       0.20      0.20      0.20         5
           4       1.00      0.80      0.89         5
           5       1.00      1.00      1.00         5
           6       0.50      0.60      0.55         5
           7       0.60      0.60      0.60         5
           8       1.00      0.40      0.57         5
           9       0.83      1.00      0.91         5
          10       0.62      1.00      0.77         5
          11       0.67      0.40      0.50         5
          12       0.80      0.80      0.80         5
          13       0.33      0.40      0.36         5
          14       0.50      0.40      0.44         5
          15       0.00      0.00      0.00         5
          16    

In [4]:
onehot_true = np.array(list(i.numpy() for i in train_data.one_hot_labels.values))
onehot_predict = torch.stack(last_predict_matrix).numpy()
APs = list(average_precision_score(onehot_true[:,i], onehot_predict[:,i]) for i in range(20))
print(np.array(APs))
print(sum(APs)/len(APs))

[0.91075269 0.62545954 0.80991592 0.25598213 0.62484393 0.56852161
 0.6023078  0.85962197 0.58594148 0.71751927 0.63595238 0.54438226
 0.80210554 0.49356448 0.23229472 0.16596729 0.72654457 0.27764698
 0.40008063 0.74716446]
0.579328482939334
