In [None]:
import multiprocessing
import os
import pickle
import sys
import typing
import zipfile
from collections import Counter

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.models import resnet
from tqdm.auto import tqdm

# Register tqdm's pandas integration for this session.
tqdm.pandas()

# Record the interpreter / library versions next to the notebook outputs.
print(
    f"""{sys.version=}""",
    f"""{pd.__version__=}""",
    f"""{np.__version__=}""",
    sep="\n",
)

In [None]:
# Intra-op / inter-op CPU thread pools. Capped at the machine's core count so a
# smaller box is not oversubscribed; 32 remains the upper bound used originally.
NUM_THREADS = min(32, os.cpu_count() or 1)
torch.set_num_threads(NUM_THREADS)
try:
    torch.set_num_interop_threads(NUM_THREADS)
except RuntimeError:
    # PyTorch only allows this once per process (and before inter-op work starts),
    # so re-running this cell in a live kernel must not abort the notebook.
    pass

In [None]:
# Data root; override with the FACES_DATA_ROOT environment variable on other machines.
root = os.environ.get("FACES_DATA_ROOT", "/home/asciishell/s3/jupyter.asciishell.ru")

In [None]:
# Keep only the image id and the target label from the training CSV.
df = (
    pd.read_csv(f"{root}/train.csv")
    .loc[:, ["id", "glasses"]]
    .copy()
)

In [None]:
# 70/30 train/validation split with a fixed seed, so the validation rows are stable across runs.
df_train, df_valid = train_test_split(
    df,
    random_state=42,
    test_size=0.3,
)

In [None]:
# Torchvision's pretrained ResNet weights (loaded by ``Model(init=True)``) expect inputs
# normalised with ImageNet channel statistics; both pipelines share the same constants.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training augmentation: random crop/flip, colour jitter and occasional grayscale.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Evaluation: no augmentation, only tensor conversion + the same normalisation.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)


class FaseDataset(torch.utils.data.Dataset):
    """Face-image dataset yielding ``(transform(image), label)`` pairs.

    Args:
        root: Path template with one ``{}`` placeholder that is filled with the row's ``id``.
        size: ``(width, height)`` passed to ``PIL.Image.resize`` before ``transform``.
        sample: Frame with an ``id`` column and the label column ``target_column``.
        transform: Callable applied to the resized RGB PIL image.
        target_column: Name of the label column in ``sample`` (default ``"glasses"``).
    """

    def __init__(
        self,
        root: str,
        size: tuple[int, int],
        sample: pd.DataFrame,
        transform: typing.Callable,
        target_column: str = "glasses",
    ):
        self.root = root
        self.size = size
        self.ids = sample["id"].values
        self.targets = sample[target_column].values
        self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        image_id, target = self.ids[index], self.targets[index]
        with Image.open(self.root.format(image_id)) as img:
            # convert("RGB") loads the pixels (so the file can be closed here) and turns
            # palette / RGBA / grayscale PNGs into the 3 channels the Normalize step expects.
            img = img.convert("RGB").resize(self.size, Image.Resampling.BILINEAR)
        return self.transform(img), target

In [None]:
# Loader batch size and the (width, height) every image is resized to.
batch_size, im_size = 128, (224,) * 2

In [None]:
# Image path template; ``{}`` is filled with the row id by FaseDataset.
face_path_template = f"{root}/faces-spring-2020/faces-spring-2020/face-{{}}.png"
# Pinned host memory only helps host->GPU copies.
pin_memory = torch.cuda.is_available()

train_loader = torch.utils.data.DataLoader(
    FaseDataset(face_path_template, im_size, df_train, train_transform),
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=pin_memory,
    drop_last=True,  # every training step sees a full batch
)
# Validation must score every held-out sample exactly once: no shuffling and no
# dropped trailing batch (drop_last=True would silently discard up to batch_size-1 rows).
test_loader = torch.utils.data.DataLoader(
    FaseDataset(face_path_template, im_size, df_valid, test_transform),
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=pin_memory,
    drop_last=False,
)

In [None]:
class Model(torch.nn.Module):
    """Torchvision ResNet backbone without its ``Linear`` classifier, plus a linear head.

    Args:
        feature_dim: Output width of the head ``g``.
        arch: ``"resnet18"``, ``"resnet34"`` or ``"resnet50"``.
        init: When true, load torchvision's ``DEFAULT`` pretrained weights; otherwise
            the backbone is randomly initialised.

    Raises:
        ValueError: ``arch`` is not a supported name (a subclass of ``Exception``, so
            existing ``except Exception`` handlers still catch it).
    """

    def __init__(self, feature_dim=128, arch="resnet18", init=True):
        super().__init__()

        if arch == "resnet18":
            backbone = resnet.resnet18(weights=resnet.ResNet18_Weights.DEFAULT if init else None)
        elif arch == "resnet34":
            backbone = resnet.resnet34(weights=resnet.ResNet34_Weights.DEFAULT if init else None)
        elif arch == "resnet50":
            backbone = resnet.resnet50(weights=resnet.ResNet50_Weights.DEFAULT if init else None)
        else:
            raise ValueError("Unknown module {}".format(repr(arch)))

        # The pooled feature width equals the input width of the classifier being dropped
        # (512 for resnet18/34, 2048 for resnet50), so derive it rather than hard-code it.
        in_size = backbone.fc.in_features

        # encoder: every top-level child except the Linear classifier. Built positionally
        # so state_dict keys stay ``f.0.*``, ``f.1.*``, ... (checkpoint-compatible).
        self.f = torch.nn.Sequential(
            *[child for child in backbone.children() if not isinstance(child, torch.nn.Linear)]
        )
        # projection head
        self.g = torch.nn.Linear(in_size, feature_dim, bias=True)

    def forward(self, x):
        """Encode ``x``, flatten the pooled features and project them to ``feature_dim``."""
        features = torch.flatten(self.f(x), start_dim=1)
        return self.g(features)
        # return F.normalize(out, dim=-1)


class ContrastiveLoss(torch.nn.Module):
    """Temperature-scaled contrastive loss over two batches of paired embeddings.

    Row ``i`` of ``out_1`` and row ``i`` of ``out_2`` are the positive pair. Every other
    row of the concatenated ``[out_1; out_2]`` batch, excluding the row itself and its
    positive, is a negative. The per-row loss is
    ``-log(exp(p / T) / (exp(p / T) + sum_j exp(n_j / T)))`` with dot-product
    similarities, averaged over all ``2 * B`` rows.

    Args:
        temperature: Divisor ``T`` applied to all similarities.
        cuda: Kept for backward compatibility; the mask now follows the inputs' device.
    """

    def __init__(self, temperature, cuda):
        super().__init__()
        self.temperature = temperature
        # Not stored as ``self.cuda``: that instance attribute would shadow
        # ``torch.nn.Module.cuda()`` and make ``criterion.cuda()`` fail.
        self.use_cuda = cuda

    def get_negative_mask(self, batch_size):
        """Return a ``(2B, 2B)`` bool mask that is False at each row's self/positive column."""
        # Row i has ones at columns i and i + B; negate to keep only the negatives.
        self_or_positive = torch.eye(batch_size, dtype=torch.bool).repeat(1, 2)
        negative_mask = ~self_or_positive
        return torch.cat((negative_mask, negative_mask), 0)

    def forward(self, out_1, out_2):
        batch_size = out_1.shape[0]

        # neg logits: all pairwise similarities of the concatenated batch, masked.
        out = torch.cat([out_1, out_2], dim=0)
        similarity = torch.mm(out, out.t()) / self.temperature
        mask = self.get_negative_mask(batch_size).to(similarity.device)
        neg_logits = similarity.masked_select(mask).view(2 * batch_size, -1)

        # pos logits: matching rows of the two views, duplicated for both halves.
        pos_logits = torch.sum(out_1 * out_2, dim=-1) / self.temperature
        pos_logits = torch.cat([pos_logits, pos_logits], dim=0)

        # -log(e^p / (e^p + sum e^n)) == logsumexp([p, n...]) - p, which avoids the
        # overflow of exponentiating unnormalised similarities directly.
        logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
        loss = (torch.logsumexp(logits, dim=1) - pos_logits).mean()

        return loss

In [None]:
def _predict(model, loader, cuda, desc):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns ``(logits, targets)`` as CPU tensors concatenated over all batches, or
    ``(None, None)`` when the loader yields no batches.
    """
    model.eval()
    logits, targets = [], []
    with torch.no_grad():
        for inputs, target in tqdm(loader, desc=desc):
            if cuda:
                inputs = inputs.cuda(non_blocking=True)
            logits.append(model(inputs).cpu())
            targets.append(target.cpu())
    if not logits:
        return None, None
    return torch.cat(logits), torch.cat(targets)


def _log_validation(writer, logits, targets, epoch):
    """Write weighted F1, accuracy and (if both classes occur) ROC AUC for one epoch."""
    predicted = logits.argmax(dim=1).numpy()
    y_true = targets.numpy()
    writer.add_scalar("valid/f1w", f1_score(y_true, predicted, average="weighted"), epoch)
    writer.add_scalar("valid/acc", accuracy_score(y_true, predicted), epoch)
    # roc_auc_score raises ValueError when y_true contains a single class.
    if np.unique(y_true).size > 1:
        positive_prob = torch.softmax(logits, dim=1)[:, 1].numpy()
        writer.add_scalar("valid/roc_auc", roc_auc_score(y_true, positive_prob), epoch)


def main(
    *,
    model,
    criterion,
    optimizer,
    writer,
    train_loader,
    valid_loader,
    cuda=True,
    epochs=200,
    valid_every=1,
):
    """Train ``model`` and log metrics to the TensorBoard ``writer`` (closed on exit).

    Args:
        model: Module producing class logits; may already be a ``DataParallel``.
        criterion: Loss called as ``criterion(logits, target)``.
        optimizer: Optimizer over ``model``'s parameters.
        writer: ``SummaryWriter``; flushed every epoch and closed when training ends.
        train_loader / valid_loader: Loaders yielding ``(inputs, target)`` batches.
        cuda: Move model and batches to the GPU and wrap the model in ``DataParallel``.
        epochs: Number of training epochs.
        valid_every: Validate every this many epochs; ``0``/``None`` disables validation.
    """
    if cuda:
        model = model.cuda()
        # Wrap exactly once, even if the caller already passed a DataParallel model.
        if not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)
    elif isinstance(model, torch.nn.DataParallel):
        # DataParallel expects its parameters on the first GPU whenever GPUs are visible,
        # so a CPU run trains the wrapped module directly.
        model = model.module

    step = 0
    try:
        for epoch in range(1, epochs + 1):
            model.train()
            total_loss, total_num = 0.0, 0
            for inputs, target in tqdm(train_loader, desc=f"Train {epoch}"):
                if cuda:
                    inputs = inputs.cuda(non_blocking=True)
                    target = target.cuda(non_blocking=True)
                loss = criterion(model(inputs), target)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                writer.add_scalar("loss/train", loss_value, step)
                step += 1

                # Weight by the actual batch length, not a notebook-global batch size.
                n_samples = inputs.size(0)
                total_num += n_samples
                total_loss += loss_value * n_samples

            if total_num:
                writer.add_scalar("loss/train_epoch", total_loss / total_num, epoch)

            if valid_every and epoch % valid_every == 0:
                logits, targets = _predict(model, valid_loader, cuda, desc=f"Validation {epoch}")
                if logits is not None:
                    _log_validation(writer, logits, targets, epoch)
            writer.flush()
    finally:
        writer.close()

In [None]:
# TensorBoard log directory for this run.
out = "exp2"
# Seed torch's global RNG so weight init and data order are comparable across reruns.
torch.manual_seed(42)
# Two output logits, trained with cross-entropy on the ``glasses`` target.
model = Model(feature_dim=2, arch="resnet18", init=True)
criterion = torch.nn.CrossEntropyLoss()
# Passed unwrapped: ``main`` handles any DataParallel wrapping itself.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

writer = SummaryWriter(out)
main(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    writer=writer,
    train_loader=train_loader,
    valid_loader=test_loader,
    cuda=False,
    epochs=30,
)