In [1]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import models
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

### Data Prep

In [2]:
img_data_dir = "/media/curttigges/project-files/datasets/danbooru/"

#### Data Examination

In [3]:
df = pd.read_pickle(img_data_dir+'image_subset_metadata.pkl')
dfs = df.sample(50000)
dfs.head()

FileNotFoundError: [Errno 2] No such file or directory: '/media/curttigges/project-files/datasets/danbooru/image_subset_metadata.pkl'

In [4]:
im = Image.open(img_data_dir+dfs.iloc[5,3])
im

NameError: name 'dfs' is not defined

In [16]:
im.size

(512, 512)

In [14]:
dfs.head()

Unnamed: 0,id,tags,image_available,paths,file_names,tagslist
8888,688054,"[{'id': '540830', 'name': '1boy', 'category': ...",1.0,danbooru-images/danbooru-images/0054/688054.jpg,688054.jpg,"[1boy, 1girl, a-i-red, asbel_lhant, bad_id, ba..."
65854,472091,"[{'id': '470575', 'name': '1girl', 'category':...",1.0,danbooru-images/danbooru-images/0091/472091.jpg,472091.jpg,"[1girl, akudama_geku, blue_eyes, blue_hair, ca..."
46820,323114,"[{'id': '540830', 'name': '1boy', 'category': ...",1.0,danbooru-images/danbooru-images/0114/323114.jpg,323114.jpg,"[1boy, 1girl, bad_id, bad_pixiv_id, bikini, bl..."
115006,2098014,"[{'id': '1821', 'name': '2girls', 'category': ...",1.0,danbooru-images/danbooru-images/0014/2098014.jpg,2098014.jpg,"[2girls, alternate_costume, bat_wings, breasts..."
46134,235119,"[{'id': '470575', 'name': '1girl', 'category':...",1.0,danbooru-images/danbooru-images/0119/235119.jpg,235119.jpg,"[1girl, animal_ears, baku_taso, bunny_ears, ch..."


In [7]:
def get_tag_dict(df, n=10):
    tag_dict = {}
    for list in df.tagslist:
        for tag in list:
            if tag in tag_dict:
                tag_dict[tag] += 1
            else:
                tag_dict[tag] = 1

    item = sorted(tag_dict.items(), key = lambda x:x[1],reverse = True)
    print(f"{n} top tags:")
    for i in range(0,n):
        print(item[i])

    return tag_dict

In [8]:
def get_top_tags(tag_dict, min_support):
    reduced_dict = {t for t in tag_dict if tag_dict[t] > min_support}
    print(f"Dictionary contains {len(reduced_dict)} tags.")
    tag_list = [l for l in reduced_dict]
    return tag_list

In [9]:
tags = get_tag_dict(dfs, 10)
classes = get_top_tags(tags, 100)

10 top tags:
('1girl', 33055)
('solo', 28992)
('long_hair', 21941)
('highres', 16827)
('smile', 14957)
('short_hair', 14436)
('multiple_girls', 12173)
('blush', 11867)
('touhou', 10975)
('looking_at_viewer', 10757)
Dictionary contains 1345 tags.


In [17]:
class_map = {n:c for n, c in enumerate(classes)}
class_map

{0: 'neptune_(series)',
 1: 'bloomers',
 2: 'check_translation',
 3: 'detached_collar',
 4: 'strapless_dress',
 5: 'skirt',
 6: 'wallpaper',
 7: 'black_bow',
 8: 'eyebrows',
 9: 'mahou_shoujo_lyrical_nanoha_strikers',
 10: 'artist_self-insert',
 11: 'white_hair',
 12: 'full_moon',
 13: 'overwatch',
 14: 'touken_ranbu',
 15: 'halterneck',
 16: 'green_background',
 17: 'valentine',
 18: 'animal_print',
 19: 'faceless',
 20: 'hood_down',
 21: 'covering',
 22: 'white_flower',
 23: 'drooling',
 24: 'cameltoe',
 25: 'sign',
 26: 'sitting',
 27: 'witch_hat',
 28: 'helmet',
 29: 'shiny_hair',
 30: 'headset',
 31: 'red_neckwear',
 32: 'kamishirasawa_keine',
 33: 'tate_eboshi',
 34: 'double_bun',
 35: 'traditional_media',
 36: 'brown_footwear',
 37: 'pink_eyes',
 38: 'name_tag',
 39: 'rose',
 40: 'back',
 41: 'thighhighs',
 42: 'side_braid',
 43: 'guitar',
 44: 'lolita_fashion',
 45: 'dated',
 46: 'visor_cap',
 47: 'bow_(weapon)',
 48: 'witch',
 49: 'multiple_tails',
 50: 'sisters',
 51: 'tenryu

In [10]:
from data.danbooru_dataset import DanbooruDataset
ds = DanbooruDataset(dfs, classes, img_data_dir, transforms=None)

In [15]:
img, labels = ds.__getitem__(20)

In [None]:
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(train_dl))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out)

#### Data Preparation

In [3]:
from data.danbooru_data_module import DanbooruDataModule

BATCH_SIZE = 64

dm = DanbooruDataModule(
    img_data_dir+'image_subset_metadata.pkl',
    img_data_dir,
    subsample=50000,
    min_support=1000,
    batch_size=BATCH_SIZE,
    num_workers=12)

In [4]:
dm.setup()

10 top tags:
('1girl', 32750)
('solo', 28658)
('long_hair', 21902)
('highres', 16880)
('smile', 14962)
('short_hair', 14176)
('multiple_girls', 12272)
('blush', 11758)
('touhou', 10835)
('open_mouth', 10529)
Dictionary contains 188 tags.


In [5]:
dm.num_classes

188

### Models

In [6]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy, precision
import torchmetrics.functional as tf
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar

In [7]:
class ResNetMultiLabel(nn.Module):
    def __init__(self, model, n_classes):
        super().__init__()
        resnet = model
        resnet.fc = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=resnet.fc.in_features, out_features=n_classes)
        )
        self.backbone = resnet
        #self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        
        out = self.backbone(x)
        
        return out

In [8]:
class ResNetMultiTrainModule(pl.LightningModule):
    def __init__(self, model, model_desc, batch_size, learning_rate, momentum, n_classes, thresh=0.5):
        super().__init__()

        # Key parameters
        self.save_hyperparameters()

        self.model = ResNetMultiLabel(model, n_classes)

    def forward(self, x):
        x = self.model(x)        
        return x

    def evaluate(self, batch, stage=None):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat, y.type(torch.float))
        acc = accuracy(y_hat, y.type(torch.int), threshold=self.hparams.thresh)
        prec = precision(y_hat, y.type(torch.int), threshold=self.hparams.thresh)
        recall = tf.recall(y_hat, y.type(torch.int), threshold=self.hparams.thresh)
        f1_score = tf.f1_score(y_hat, y.type(torch.int), threshold=self.hparams.thresh)
        rmap = tf.retrieval_average_precision(y_hat, y.type(torch.int))

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)
            self.log(f"{stage}_prec", prec, prog_bar=True)
            self.log(f"{stage}_recall", recall, prog_bar=True)
            self.log(f"{stage}_f1_score", f1_score, prog_bar=True)
            self.log(f"{stage}_rmap", rmap, prog_bar=True)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat, y.type(torch.float))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), 
            lr=self.hparams.learning_rate,
            betas=(0.9,0.999))
        #optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate, momentum=self.hparams.momentum)
        
        # steps_per_epoch = 60000 // self.hparams.batch_size
        # '''
        # lr_scheduler_dict = {
        #     "scheduler":MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        # }
        # '''
        # lr_scheduler_dict = {
        #     "scheduler":OneCycleLR(
        #         optimizer,
        #         self.hparams.learning_rate,
        #         epochs=self.trainer.max_epochs,
        #         steps_per_epoch=steps_per_epoch,
        #         anneal_strategy='cos'
        #     ),
        #     "interval":"step",
        # }
        #return {"optimizer":optimizer, "lr_scheduler":lr_scheduler_dict}
        return optimizer

In [9]:
pl_model = ResNetMultiTrainModule(
    models.resnet50(pretrained=True),
    model_desc="resnet50",
    batch_size=BATCH_SIZE, 
    learning_rate=0.001, 
    momentum=0.9, 
    n_classes=dm.num_classes)

  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."


In [10]:
wandb_logger = WandbLogger(project="resnet-danbooru")
wandb_logger.watch(pl_model, log="all")

trainer = pl.Trainer(
    max_epochs=10,
    precision=16,
    accelerator='gpu', 
    devices=1,
    logger=wandb_logger,
    profiler="simple",
    callbacks=[TQDMProgressBar(refresh_rate=10)])
trainer.fit(pl_model, dm)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcurt-tigges[0m ([33mascendant[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


10 top tags:
('1girl', 32750)
('solo', 28658)
('long_hair', 21902)
('highres', 16880)
('smile', 14962)
('short_hair', 14176)
('multiple_girls', 12272)
('blush', 11758)
('touhou', 10835)
('open_mouth', 10529)
Dictionary contains 188 tags.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNetMultiLabel | 23.9 M
-------------------------------------------
23.9 M    Trainable params
0         Non-trainable params
23.9 M    Total params
47.786    Total estimated model params size (MB)


Epoch 9: 100%|██████████| 692/692 [26:20<00:00,  2.28s/it, loss=0.219, v_num=1mg0, val_loss=0.268, val_acc=0.928, val_prec=0.624, val_recall=0.0606, val_f1_score=0.110, val_rmap=0.262]   


FIT Profiler Report

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                                                    	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                       