<a href="https://colab.research.google.com/github/debemdeboas/pucrs-aprendizado-de-maquina-t2/blob/master/PUCRS_Aprendizado_de_M%C3%A1quina_T2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import tensorflow as tf
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets
from torchvision import transforms as T
from torchvision.transforms import functional as TF
from tqdm import tqdm
from enum import IntEnum

Download and extract our dataset.

This will download a tarred file and extract it into `dist`. Then, we're renaming it to `animes/`.
This directory contains the following files:
- `animes.csv`, the CSV containing anime IDs, URLs, titles, genres, and poster path
- `animes.pkl`, serialized (pickled) list of `Anime` instances. This isn't used by this notebook
- `images/`, a directory that contains all of our anime posters as `images/<mal_id>.jpg` files


In [2]:
ds = requests.get("https://public-s3.debem.dev/anime_dataset.tar.xz", allow_redirects=True)

with open("anime_dataset.tar.xz", "wb") as f:
    f.write(ds.content)

!tar xf anime_dataset.tar.xz

In [3]:
df = pd.read_csv("animes.csv")
df = df.dropna()
df["genres"] = df["genres"].str.split('|')

df["img"] = df["img_path"].map(lambda x: Image.open(x).convert("RGB"))

In [154]:
all_genres_to_idx = dict()
all_genres_amnt = dict()
all_genres = list()
for gl in df.genres:
    for g in gl:
        if g not in all_genres_to_idx:
            all_genres.append(g)
            all_genres_to_idx[g] = len(all_genres) - 1
            all_genres_amnt[g] = 0
        all_genres_amnt[g] += 1

In [54]:
class AccIdx(IntEnum):
    TP = 0
    FP = 1
    TN = 2
    FN = 3

class PosterMultiLabelDataset(Dataset):
    def __init__(self, df: pd.DataFrame, transform=None, *args, **kwargs):
        self.df = df
        if transform:
            self.transform = transform
        else:
            self.transform = T.Compose([
                T.RandomHorizontalFlip(p=0.5),
                T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0),
                T.Resize((256, 256)),
                T.RandomResizedCrop(224),
                T.ToTensor(),
                T.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
                ])

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx):
        # img_path = self.df.loc[idx].img_path
        # img = Image.open(img_path)
        # img = img.convert('RGB')
        # return self.transform(img), torch.Tensor([1 if g in self.df.loc[idx].genres else 0 for g in all_genres])
        # return {
        #     "image": self.transform(img),
        #     "labels": torch.Tensor([1 if g in self.df.loc[idx].genres else 0 for g in all_genres])
        # }
        return self.transform(self.df.loc[idx].img), torch.Tensor([1 if g in self.df.loc[idx].genres else 0 for g in all_genres])

Now let's get some transfer learning done.

We'll use a pre-trained convolutional network to analyze the posters to define which genres a given anime belongs to.
Each anime can belong to any number of genres.


In [29]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def confusion_matrix(model, loader):
    model.eval()
    confusion_matrix = np.zeros((len(all_genres), len(all_genres)))
    with torch.no_grad():
        for (img, lbl) in loader:
            img = img.to(device)
            lbl = lbl.to(device)
            output = model(img)
            predictions = output > 0.5
                # confusion_matrix[idx_true][idx_pred] += int(predictions[i][j].item())
            # list(zip(zip(ls[0].tolist(), predictions[0].tolist()), all_genres))
            for ll, pp in zip(lbl, predictions):
                for i, (l, p) in enumerate(zip(ll, pp)):
                    print(all_genres[i], l.item(), p.item())
                    break
                break
                # confusion_matrix[idx_true][idx_pred] += 1
    ax = sns.heatmap(confusion_matrix, annot=True, cmap='Blues', fmt='g')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Label')
    return ax


def validation(model, loader, criterion):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for (images, labels) in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs,labels)
            val_loss +=loss
    return val_loss/len(loader)


def train(model, trainloader, testloader, optimizer, criterion, epochs):
    for epoch in range(epochs):
        model.train()
        running_loss = 0
        for i, (images, labels) in tqdm(enumerate(trainloader)):
            images = images.to(device)
            labels = labels.to(device)
            model.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        val_loss = validation(model, testloader, criterion)
        print(f'Epoch: {epoch+1} | Loss: {running_loss/len(trainloader)} | Val Loss: {val_loss}')


def accuracy(model, loader, positive_threshold = 0):
    model.eval()
    correct = 0
    total = 0
    # Create a numpy matrix of $len(all_genres) x 4$ à la confusion matrix
    totals = np.zeros((len(all_genres), 4))
    with torch.no_grad():
        for (img, lbl) in loader:
            img = img.to(device)
            lbl = lbl.to(device)
            output = model(img)
            tp = (output > positive_threshold) & (lbl > 0)
            fp = (output > positive_threshold) & (lbl == 0)
            tn = (output < positive_threshold) & (lbl == 0)
            fn = (output < positive_threshold) & (lbl > 0)
            stacked = torch.stack((tp, fp, tn, fn), dim=2)
            totals += torch.sum(stacked, dim=0).cpu().numpy()
    return totals


In [7]:
# from torchvision.models import resnet50, ResNet50_Weights
# from torchvision.models import regnet_y_16gf, RegNet_Y_16GF_Weights
from torchvision.models import regnet_x_3_2gf, RegNet_X_3_2GF_Weights

torch.backends.cudnn.benchmark = True # speed up training by using the inbuilt cudnn auto-tuner

if False:
    # model = resnet50(weights=ResNet50_Weights.DEFAULT)
    # model.fc = nn.Linear(2048, len(all_genres))

    model = regnet_x_3_2gf(weights=RegNet_X_3_2GF_Weights.DEFAULT)
    model.fc = nn.Linear(1008, len(all_genres))

    for name, params in model.named_parameters():
        if name not in ('fc.weight', 'fc.bias'): #final block
            params.requires_grad = False
else:
    # download pre-trained version from our bucket:
    model_f = requests.get("https://public-s3.debem.dev/model.pt", allow_redirects=True)
    with open('model.pt', 'wb') as f:
        f.write(model_f.content)
    model = torch.load('model.pt')

model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.01)
epochs = 6

In [8]:
# dados de treino = 70%, validação = 15%, teste = 15%
trainData = df.sample(frac=0.7)
trainDataLeftover = df.drop(trainData.index)
validationData = trainDataLeftover.sample(frac=0.5)
testData = trainDataLeftover.drop(validationData.index)

testData = testData.reset_index()
validationData = validationData.reset_index()
trainData = trainData.reset_index()

In [54]:
trainloader = DataLoader(PosterMultiLabelDataset(df=trainData), batch_size=124, shuffle=True, pin_memory=True, num_workers=8)
testloader = DataLoader(PosterMultiLabelDataset(df=testData), batch_size=124, shuffle=False, pin_memory=True, num_workers=8)

In [None]:
train(model, trainloader, testloader, optimizer, criterion, epochs)

In [9]:
def predict(model, data: pd.Series, transform):
    img = Image.open(data["img_path"])
    img = transform(img)
    prediction = model(torch.unsqueeze(img, 0).to(device))
    print(prediction)
    return prediction

In [208]:
torch.save(model, 'model.pt')

In [200]:
transforms_val = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]),
    ])

valloader = DataLoader(
    PosterMultiLabelDataset(
        df=validationData,
        transform=transforms_val),
    batch_size=50,
    shuffle=False)
acc = accuracy(model, valloader, positive_threshold=1)

In [141]:
def get_acc_info(acc):
    total_pos = np.sum(acc[:2])
    total_neg = np.sum(acc[2:4])
    with np.errstate(divide='ignore', invalid='ignore'):
        return (
            np.nan_to_num(np.divide(np.sum(acc[AccIdx.TP]), total_pos)),
            np.nan_to_num(np.divide(np.sum(acc[AccIdx.FP]), total_pos)),
            np.nan_to_num(np.divide(np.sum(acc[AccIdx.TN]), total_neg)),
            np.nan_to_num(np.divide(np.sum(acc[AccIdx.FN]), total_neg)),
        )

def get_acc(tp, fp, tn, fn):
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.nan_to_num((tn + tp) / (tn + fp + tp + fn))

def get_f1_score(tp, fp, tn, fn):
    with np.errstate(divide='ignore', invalid='ignore'):
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = (2 * (precision * recall)) / (precision + recall)
        return map(np.nan_to_num, (precision, recall, f1))

In [203]:
per_cat_results = dict()
for i, category in enumerate(all_genres):
    tp, fp, tn, fn = get_acc_info(acc[i])
    curr_acc = get_acc(tp, fp, tn, fn)
    precision, recall, f1 = get_f1_score(tp, fp, tn, fn)
    per_cat_results[category] = (
        '+' + f' {category} '.center(31, '-') + '+\n'
        '|' + f' {all_genres_amnt[category]} instances'.center(31) + '|\n'
        f'| True positive:  {(tp*100):12g}% |\n'
        f'| False positive: {(fp*100):12g}% |\n'
        f'| True negative:  {(tn*100):12g}% |\n'
        f'| False negative: {(fn*100):12g}% |\n'
        f'| Accuracy:       {(curr_acc*100):12g}% |\n'
        f'| Precision:      {(precision*100):12g}% |\n'
        f'| Recall:         {(recall*100):12g}% |\n'
        f'| F1 Score:       {(f1*100):12g}% |\n'
        f'+{"":-^31}+\n'
    )

In [206]:
for _, v in sorted(per_cat_results.items(), key=lambda kv: kv[0]):
    print(v)

+------------ Action -----------+
|         4954 instances        |
| True positive:       76.0766% |
| False positive:      23.9234% |
| True negative:        81.993% |
| False negative:       18.007% |
| Accuracy:            79.0348% |
| Precision:           76.0766% |
| Recall:              80.8606% |
| F1 Score:            78.3957% |
+-------------------------------+

+---------- Adult Cast ---------+
|          536 instances        |
| True positive:             0% |
| False positive:            0% |
| True negative:       97.6105% |
| False negative:      2.38945% |
| Accuracy:            97.6105% |
| Precision:                 0% |
| Recall:                    0% |
| F1 Score:                  0% |
+-------------------------------+

+---------- Adventure ----------+
|         3954 instances        |
| True positive:       91.6667% |
| False positive:      8.33333% |
| True negative:       82.5847% |
| False negative:      17.4153% |
| Accuracy:            87.1257% |
| Precision:

In [212]:
total_pos = np.sum(acc[:, :2])
total_neg = np.sum(acc[:, 2:4])
tp = np.sum(acc[:, AccIdx.TP]) / total_pos
fp = np.sum(acc[:, AccIdx.FP]) / total_pos
tn = np.sum(acc[:, AccIdx.TN]) / total_neg
fn = np.sum(acc[:, AccIdx.FN]) / total_neg

curr_acc = get_acc(tp, fp, tn, fn)
precision, recall, f1 = get_f1_score(tp, fp, tn, fn)
print('+' + f' totals '.center(31, '-') + '+')
print('|' + f' {len(validationData)} instances'.center(31) + '|')
print(f'| True positive:  {(tp*100):12g}% |')
print(f'| False positive: {(fp*100):12g}% |')
print(f'| True negative:  {(tn*100):12g}% |')
print(f'| False negative: {(fn*100):12g}% |')
print(f'| Accuracy:       {(curr_acc*100):12g}% |')
print(f'| Precision:      {(precision*100):12g}% |')
print(f'| Recall:         {(recall*100):12g}% |')
print(f'| F1 Score:       {(f1*100):12g}% |')
print(f'+{"":-^31}+')

+------------ totals -----------+
|         3641 instances        |
| True positive:        61.147% |
| False positive:       38.853% |
| True negative:       96.4841% |
| False negative:       3.5159% |
| Accuracy:            78.8155% |
| Precision:            61.147% |
| Recall:              94.5627% |
| F1 Score:            74.2693% |
+-------------------------------+
