In [36]:
import os
import random
import datetime

import torch
import torchmetrics
import torchsummary
import numpy as np
import torch.nn.functional as F
from torch import nn
from PIL import Image
from torch.utils.data import Dataset
from torchvision.io import read_image
from facenet_pytorch import InceptionResnetV1, MTCNN
from torch.utils.tensorboard import SummaryWriter


from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

In [157]:
# --- Configuration -----------------------------------------------------------
DATASET_PATH = "data"  # root folder that holds the per-class image sub-folders
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EMBED_DIM = 32  # number of PCA components kept, i.e. the classifier input size
SEED = 69

# Seed every RNG this notebook draws from so that Restart & Run All is
# repeatable: `random` drives the sampling in EmbedDataset, torch drives
# weight initialisation and dropout.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# `mtcnn` pre-processes a PIL image into the tensor that `resnet` embeds.
# `resnet` is switched to eval mode because it is only used for inference.
mtcnn = MTCNN()
resnet = InceptionResnetV1(pretrained='vggface2').eval()

In [158]:
# Set all three MTCNN thresholds to zero.
# NOTE(review): presumably so that no candidate is rejected and every image
# yields a crop - confirm against the facenet_pytorch documentation.
mtcnn.thresholds = [0.0] * 3

In [170]:
# Collect the image paths of both classes.
#
# * Paths are built by plain string concatenation ("<root>/<subdir>/<file>")
#   because these exact strings are used as keys of the embedding dictionary.
# * Each directory listing is sorted: os.listdir returns entries in arbitrary
#   order, which would make the seeded train/val split below differ from one
#   machine (or run) to another.
files_opposite = np.array([
    f"{DATASET_PATH}/{subdir}/{fname}"
    for subdir in ("markup_opposite", "opposite")
    for fname in sorted(os.listdir(f"{DATASET_PATH}/{subdir}/"))
])
files_target = np.array([
    f"{DATASET_PATH}/{subdir}/{fname}"
    for subdir in ("markup_target", "target")
    for fname in sorted(os.listdir(f"{DATASET_PATH}/{subdir}/"))
])

# An empty class would silently produce a one-class training problem.
assert len(files_opposite) > 0, "no images found for the 'opposite' class"
assert len(files_target) > 0, "no images found for the 'target' class"

# Labels: 0.0 = opposite, 1.0 = target.
y_opposite = np.zeros(len(files_opposite), dtype='float32')
y_target = np.ones(len(files_target), dtype='float32')

X = np.concatenate([files_opposite, files_target])
Y = np.concatenate([y_opposite, y_target])
assert len(X) == len(Y)

# 80 / 20 train / validation split with a fixed seed.
Xtrain, Xval, Ytrain, Yval = train_test_split(X, Y, test_size=0.2, random_state=69)

In [171]:
# Held-out test set, read from its own pair of folders.
Xtest = []
Ytest = []

# (sub-folder, label) pairs. Labels follow the train/val encoding:
# 1.0 = target, 0.0 = opposite. The opposite folder used to be labelled 1.0
# as well, which made every test sample a "target".
for subdir, label in (("test_target", 1.), ("test_opposite", 0.)):
    path = DATASET_PATH + "/" + subdir + "/"
    ldir = [path + f for f in os.listdir(path)]
    Xtest += ldir
    Ytest += [label] * len(ldir)

Xtest = np.array(Xtest)
Ytest = np.array(Ytest, dtype='float32')

In [172]:
# Face embeddings, cached on disk.
#
# `Xembs` maps an image path to the numpy array returned by `resnet` for a
# batch of one image, so each value has a leading axis of length 1.
# The cache file is loaded when present; embeddings are computed only for
# paths that are missing from it, and the file is rewritten if anything was
# added. This keeps the cell runnable on a fresh kernel.
#
# NOTE: torch.load unpickles arbitrary objects - only load a cache file that
# was produced by this notebook.
EMBEDS_CACHE = 'InceptionResnetV1_vggface2.dict'

Xembs = torch.load(EMBEDS_CACHE) if os.path.exists(EMBEDS_CACHE) else {}

missing_paths = [
    x_path
    for _x in [Xtrain, Xval, Xtest]
    for x_path in _x
    if x_path not in Xembs
]

if missing_paths:
    with torch.no_grad():
        for x_path in missing_paths:
            with Image.open(x_path) as img:
                face = mtcnn(img)
            if face is None:
                # Fail with the offending file name instead of an
                # AttributeError on `None.unsqueeze`.
                raise RuntimeError(f"MTCNN returned no face for {x_path}")
            Xembs[x_path] = resnet(face.unsqueeze(0)).numpy()

    torch.save(Xembs, EMBEDS_CACHE)

In [173]:
# Stack the cached embeddings of the training images along the first axis;
# only the training split is used so the PCA never sees validation/test data.
embeds = np.concatenate([Xembs[train_path] for train_path in Xtrain], axis=0)

In [174]:
# Reduce the embeddings to EMBED_DIM components.
# random_state pins the result when scikit-learn picks a randomized SVD
# solver; the fitted projection is pickled further down and the classifier is
# trained on it, so it has to be reproducible. 69 matches the seed used for
# the train/val split.
pca = PCA(n_components=EMBED_DIM, random_state=69)
pca.fit(embeds)

In [175]:
class EmbedDataset(Dataset):
    """Class-balanced random sampler over pre-computed embeddings.

    The requested index is ignored: every ``__getitem__`` call first picks a
    class with (roughly) equal probability and then a random sample of that
    class, so batches are balanced regardless of the class ratio in ``y``.

    Args:
        x: numpy array of sample keys (image paths) used to index ``embeds``.
        y: numpy array of labels, ``0.`` for opposite and ``1.`` for target;
            must have the same length as ``x``.
        embeds: mapping from key to embedding array. Only row 0 of the
            (optionally transformed) array is returned, so each value is
            expected to have a leading axis of length 1.
        decompositor: optional callable applied to the embedding (for example
            ``pca.transform``); its result is cast to float32.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if either class has no samples, because random sampling
            from that class would be impossible.
    """

    def __init__(self, x, y, embeds,
                 decompositor=None,
                 **kwargs):
        assert len(x) == len(y)

        opposite_mask = y == 0.
        target_mask = y == 1.
        # Fail here rather than with an IndexError inside a DataLoader worker.
        if not opposite_mask.any() or not target_mask.any():
            raise ValueError(
                "EmbedDataset needs at least one sample of each class "
                "(labels 0. and 1.)"
            )

        self.embeds = embeds
        self.x_opposite = x[opposite_mask]
        self.x_target = x[target_mask]
        self.y_opposite = y[opposite_mask]
        self.y_target = y[target_mask]

        self.decompositor = decompositor

    def __len__(self):
        # Nominal epoch size: total number of labelled samples.
        return len(self.y_opposite) + len(self.y_target)

    def __getitem__(self, _):
        # Choose the class first, then a random member of that class.
        if random.random() > 0.5:
            keys, labels = self.x_opposite, self.y_opposite
        else:
            keys, labels = self.x_target, self.y_target

        idx = random.randrange(len(keys))

        x = self.embeds[keys[idx]]
        if self.decompositor is not None:
            x = self.decompositor(x).astype('float32')

        # Drop the leading batch axis of the features; give the scalar label
        # a length-1 axis so it lines up with the model's single logit.
        return x[0], labels[idx][None, ...]

In [176]:
class TestDataset(Dataset):
    """Index-addressable dataset over pre-computed embeddings.

    Unlike a random sampler, item ``idx`` always corresponds to ``x[idx]`` /
    ``y[idx]``, so it suits validation and test evaluation.

    Args:
        x: sequence of keys (image paths) used to index ``embeds``.
        y: array of labels, same length as ``x``.
        embeds: mapping from key to embedding array; only row 0 of the
            (optionally transformed) array is returned.
        decompositor: optional callable applied to the embedding; its result
            is cast to float32.
        **kwargs: accepted and ignored.
    """

    def __init__(self, x, y, embeds, decompositor=None, **kwargs):
        assert len(x) == len(y)

        self.x, self.y = x, y
        self.embeds, self.decompositor = embeds, decompositor

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        label = self.y[idx]
        features = self.embeds[self.x[idx]]

        if self.decompositor:
            features = self.decompositor(features).astype('float32')

        # Features lose their leading axis; the label gains a length-1 axis.
        return features[0], label[None, ...]

In [177]:
class Model(nn.Module):
    """Small MLP that maps an embedding to a single raw logit.

    Layout for ``d=[h1, h2]``::

        Linear(input_dim, h1) -> Dropout -> GELU ->
        Linear(h1, h2)        -> Dropout -> GELU ->
        Linear(h2, 1)

    Dropout is applied to hidden layers only. Putting it after the last
    ``Linear`` would zero or double the output logit at random during
    training and corrupt the loss. The ``Linear`` layers keep their positions
    inside ``self.seq`` (0, 3, 6, ...), so state dicts saved by the previous
    layout still load.

    Args:
        input_dim: size of the input feature vector.
        d: sizes of the hidden layers (any sequence of ints).
        **kwargs: accepted and ignored.
    """

    def __init__(self, input_dim, d=(32, 32), **kwargs):
        super().__init__()

        # Copy into a fresh list so the caller's sequence is never shared.
        dims = [input_dim, *d, 1]
        last = len(dims) - 2

        seq = []
        for i in range(len(dims) - 1):
            seq.append(nn.Linear(dims[i], dims[i + 1]))
            if i != last:
                seq.append(nn.Dropout(p=0.5))
                seq.append(nn.GELU())

        self.seq = nn.Sequential(*seq)

    def forward(self, x: torch.Tensor):
        """Return raw logits (no sigmoid applied) for the input batch."""
        return self.seq(x)

In [178]:
class Trainer:
    """Minimal training / evaluation helper around a model and an optimizer.

    Args:
        model: the ``nn.Module`` to optimise.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar
            tensor.
        optimizer: optimizer built over ``model``'s parameters.
        stop_batch: batch index at which an epoch is cut short; the batch
            whose index first reaches this value is still trained on.
        metric: optional training metric, called as
            ``metric(probabilities, int_targets)`` after every step. If it
            has a ``reset`` method, it is reset at the start of each epoch.
        device: device that batches are moved to.
        fp16: when true, train with CUDA autocast and a ``GradScaler``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, 
        model, loss_fn, optimizer, 
        stop_batch, metric=None,
        device='cuda', fp16=False, 
        **kwargs):
        self.model: nn.Module = model
        self.device = device
        self.metric = metric

        self.stop_batch = stop_batch
        self.loss_fn = loss_fn
        self.optimizer = optimizer

        self.fp16 = fp16
        if fp16:
            self.scaler = torch.cuda.amp.GradScaler()

    def checkpoint(self) -> dict:
        """Return the state dicts of the model, optimizer and (fp16) scaler."""
        cpoint = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        if self.fp16:
            cpoint["scaler"] = self.scaler.state_dict()

        return cpoint

    def train(self, dataset, epoch) -> float:
        """Run one (possibly truncated) epoch.

        Args:
            dataset: iterable yielding ``(X, Y)`` tensor batches.
            epoch: current epoch number (unused here).

        Returns:
            The sum of the per-batch losses of this epoch.
        """
        self.model.train()

        # Start every epoch from a clean metric; otherwise `metric.compute()`
        # reports a value accumulated over all epochs so far.
        reset = getattr(self.metric, "reset", None)
        if callable(reset):
            reset()

        running_loss = 0.0
        for idx, batch in enumerate(dataset):
            X, Y = batch
            X, Y = X.to(self.device), Y.to(self.device)
            running_loss += self.__train(X, Y)
            if idx >= self.stop_batch:
                break

        return running_loss

    def val(self, dataset) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the model over ``dataset`` without gradients.

        Returns:
            ``(predictions, targets)``: raw model outputs moved to the CPU and
            the targets exactly as yielded by ``dataset``, each concatenated
            along the first axis.
        """
        self.model.eval()
        val_pred, val_true = [], []
        with torch.inference_mode():
            for batch in dataset:
                X, Y = batch
                val_pred.append(self.model(X.to(self.device)).cpu())
                val_true.append(Y)
        return torch.cat(val_pred), torch.cat(val_true)

    def __train(self, X, Y) -> float:
        """Perform one optimisation step and return the batch loss."""
        self.optimizer.zero_grad()
        if self.fp16:
            with torch.cuda.amp.autocast(enabled=True):
                outputs = self.model(X)
                loss = self.loss_fn(outputs, Y)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

        else:
            outputs = self.model(X)
            loss = self.loss_fn(outputs, Y)
            loss.backward()
            self.optimizer.step()

        if self.metric is not None:
            # Detach first: the metric may store its inputs, and a
            # graph-attached tensor would keep the autograd graph alive.
            self.metric(outputs.detach().sigmoid(), Y.int())

        return loss.item()

In [179]:
# Training batches: EmbedDataset draws class-balanced random samples, and the
# fitted PCA projects every embedding before it reaches the model.
trainLoader = torch.utils.data.DataLoader(
    dataset=EmbedDataset(x=Xtrain, y=Ytrain, embeds=Xembs, decompositor=pca.transform),
    batch_size=2048,
)

In [180]:
# Validation batches: every sample exactly once, in the order of Xval / Yval.
valLoader = torch.utils.data.DataLoader(
    dataset=TestDataset(x=Xval, y=Yval, embeds=Xembs, decompositor=pca.transform),
    batch_size=2048,
)

In [181]:
# Test batches: every held-out sample exactly once, in the order of Xtest / Ytest.
testLoader = torch.utils.data.DataLoader(
    dataset=TestDataset(x=Xtest, y=Ytest, embeds=Xembs, decompositor=pca.transform),
    batch_size=2048,
)

In [185]:
# Build the classifier on the configured device instead of hard-coding CUDA,
# so the notebook also runs on CPU-only machines.
model = Model(EMBED_DIM, d=[32, 32]).to(DEVICE)
trainer = Trainer(
    model=model,
    # ~4.88: with Trainer.train's `idx >= stop_batch` check, an epoch ends
    # after the batch with index 5, i.e. after 6 batches of 2048 samples.
    stop_batch=10_000/2048,
    metric=torchmetrics.AUROC(),
    # `reduce=True` is a deprecated spelling of the default mean reduction.
    loss_fn=nn.BCEWithLogitsLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=3e-4),
    device=DEVICE,
)

# Validation metrics used by the training loop.
acc = torchmetrics.Accuracy()
auc = torchmetrics.AUROC()



In [186]:
# Run name, also used as the checkpoint file name. The PCA size is taken from
# EMBED_DIM so the label cannot drift from the configuration (it used to say
# "pca 64" while EMBED_DIM was 32).
name = f'InceptionResnetV1 vggface2 pca {EMBED_DIM}'

# Separate name and timestamp with a space; they used to be glued together.
timestamp = datetime.datetime.now().strftime("%Y.%m.%d - %H-%M-%S")
board_name = f"{name} {timestamp}"

log_dir = f"logs/fit/{board_name}"
writer = SummaryWriter(log_dir)

In [187]:
# Train with early stopping on validation AUC: stop once `patience` epochs in
# a row fail to improve it. Ctrl-C / "Interrupt kernel" ends training cleanly.
try:
    wait = 0
    patience = 50

    epoch = 0
    # NOTE: despite its name this holds the best (highest) validation AUC.
    best_loss = -np.inf

    # torch.save does not create missing directories; without this the first
    # checkpoint write fails after a full epoch of training.
    os.makedirs('models/w', exist_ok=True)

    while wait < patience:
        train_loss = trainer.train(trainLoader, epoch)

        val_pred, val_true = trainer.val(valLoader)
        val_prob = val_pred.sigmoid()
        val_target = val_true.int()
        metrics = {
            'AUC': auc(val_prob, val_target),
            'ACC': acc(val_prob, val_target),
        }
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('AUC/train', trainer.metric.compute(), epoch)
        writer.add_scalar('AUC/val', metrics['AUC'], epoch)
        writer.add_scalar('ACC/val', metrics['ACC'], epoch)

        wait += 1
        epoch += 1
        if metrics['AUC'] > best_loss:
            # New best model: save it and restart the patience counter.
            checkpoint = trainer.checkpoint()
            torch.save(checkpoint, f'models/w/{name}.torch')
            best_loss = metrics['AUC']
            wait = 0


except KeyboardInterrupt:
    print("KeyboardInterrupt")

In [188]:
# Raw logits and labels for the held-out test set. This uses the model as it
# stands when the training loop ends; no checkpoint is reloaded in this cell.
test_pred, test_true = trainer.val(dataset=testLoader)

In [195]:
import matplotlib.pyplot as plt


# Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and
# later releases dropped the old name. Use whichever this installation has.
plt.style.use('seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8')

In [200]:
import pickle

In [201]:
# Persist the fitted PCA next to the model checkpoints. The context manager
# guarantees the file is flushed and closed (the bare open() call leaked the
# handle).
os.makedirs("models", exist_ok=True)
with open("models/pca.pkl", "wb") as pca_file:
    pickle.dump(pca, pca_file)

In [None]:
# Show every validation image together with the predicted probability and the
# true label. `val_pred` / `val_true` are whatever the last executed epoch of
# the training loop left behind.
for conf, label, pic in zip(val_pred, val_true, Xval):
    print(conf.sigmoid(), label)
    # Close each image file once it has been drawn; Image.open alone keeps
    # the handle open.
    with Image.open(pic) as img:
        plt.imshow(img)
    plt.grid(False)
    plt.show()