In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets
import torch.autograd as autograd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np 
from torchvision.transforms import ToTensor
import pandas as pd
import os
from torchvision.datasets import ImageFolder 
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import Counter
from sklearn.metrics import confusion_matrix, f1_score, roc_curve, roc_auc_score, recall_score, precision_score

In [None]:
# Pick the compute device once; every model and batch below is moved to DEVICE.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Device : ", DEVICE)

In [None]:
# --- Optimisation settings ---
lr = 1e-4          # learning rate handed to every AdamW optimizer
BATCH_SIZE = 16    # samples per DataLoader batch
EPOCHS = 100       # number of passes of the training loop

# --- Model geometry ---
resize = 128       # images are resized to (resize, resize) before encoding
latent_dim = 256   # width of every latent row consumed by the discriminator

In [None]:
class Img_Encoder(nn.Module):
    """Convolutional image encoder that emits rows of ``latent_dim`` features.

    Four stride-2 3x3 convolutions (3 -> 32 -> 64 -> 128 -> 256 channels)
    shrink each spatial side by a factor of 16 for inputs whose side is a
    multiple of 16. The resulting feature map is cut into rows of
    ``latent_dim`` values and each row is passed through a linear layer.

    Args:
        latent_dim: Width of each output row. The total number of feature-map
            values per batch must be divisible by it.
    """

    def __init__(self, latent_dim):
        super(Img_Encoder, self).__init__()
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.fc = nn.Linear(self.latent_dim, self.latent_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch of shape ``(batch, 3, H, W)``.

        Returns:
            Tensor of shape ``(rows, latent_dim)`` where
            ``rows = batch * 256 * (H // 16) * (W // 16) / latent_dim``.
        """
        x = self.encoder(x)
        # Use the width this instance was built with rather than a module-level
        # global, so the reshape always agrees with ``self.fc``.
        # NOTE(review): ``view`` slices the contiguous (N, C, H, W) buffer, so a
        # row holds consecutive memory values, not one spatial location's
        # channel vector -- confirm this grouping is intended.
        x = x.view(-1, self.latent_dim)
        x = self.fc(x)
        return x

In [None]:
class MLP_Encoder(nn.Module):
    """Two-layer perceptron whose output is regrouped into latent-width rows.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of the hidden layer.
        output_size: Width of the second linear layer's output; must be a
            multiple of the latent width.
        latent_dim: Width of each output row. When omitted, the module-level
            ``latent_dim`` is looked up at forward time, as before.
    """

    def __init__(self, input_size, hidden_size, output_size, latent_dim=None):
        super(MLP_Encoder, self).__init__()
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, input_size)``.

        Returns:
            Tensor of shape ``(batch * output_size / width, width)`` where
            ``width`` is the configured latent width.
        """
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        # Prefer the width given to the constructor; only fall back to the
        # notebook-level global when none was supplied.
        row_width = self.latent_dim if self.latent_dim is not None else latent_dim
        x = x.view(-1, row_width)
        return x

In [None]:
class Dual_Encoder(nn.Module):
    """Container that holds two encoders and runs each on its own input.

    Both encoders are registered as sub-modules (``encoder1``, ``encoder2``),
    so ``parameters()`` of this module covers the pair.
    """

    def __init__(self, encoder1, encoder2):
        super().__init__()
        self.encoder1, self.encoder2 = encoder1, encoder2

    def forward(self, x1, x2):
        """Return ``(encoder1(x1), encoder2(x2))``."""
        return self.encoder1(x1), self.encoder2(x2)

In [None]:
class Discriminator(nn.Module):
    """MLP that maps a latent row to a probability in (0, 1).

    Layer widths are ``latent_dim -> 128 -> 64 -> 1`` with LeakyReLU(0.2)
    between the linear layers and a final sigmoid.

    Args:
        latent_dim: Width of the input rows. Defaults to the value the
            module-level ``latent_dim`` had when this class was defined.
    """

    def __init__(self, latent_dim=latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        layers = [
            nn.Linear(self.latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Score latent rows ``z`` of shape ``(rows, latent_dim)``; returns ``(rows, 1)``."""
        return self.model(z)

In [None]:
class CrossAttention(nn.Module):
    """Row-wise attention between an image feature and a second feature.

    Queries are projected from ``img_feat``; keys and values are projected
    from ``word_feat``. Both inputs are expected to be 2-D with the same
    number of rows (``torch.bmm`` requires it after the unsqueezes below).

    NOTE(review): every row attends over exactly one key, so the softmax is
    taken over a singleton axis and the weights are always 1. The output
    therefore equals ``value_linear(word_feat)``, and the query and key
    projections do not influence it -- confirm whether that is intended.
    """

    def __init__(self, img_feat_dim, word_feat_dim, attn_dim):
        super().__init__()

        self.query_linear = nn.Linear(img_feat_dim, attn_dim)
        self.key_linear = nn.Linear(word_feat_dim, attn_dim)
        self.value_linear = nn.Linear(word_feat_dim, attn_dim)

    def forward(self, img_feat, word_feat):
        """Return the attended context, shape ``(rows, attn_dim)``."""
        q = self.query_linear(img_feat)
        k = self.key_linear(word_feat)
        v = self.value_linear(word_feat)

        # (rows, 1, D) @ (rows, D, 1) -> (rows, 1, 1) -> (rows, 1)
        scores = torch.bmm(k.unsqueeze(1), q.unsqueeze(-1)).squeeze(2)
        weights = F.softmax(scores, dim=1)

        # (rows, D, 1) @ (rows, 1, 1) -> (rows, D, 1) -> (rows, D)
        value_columns = v.unsqueeze(1).transpose(1, 2)
        weighted = torch.bmm(value_columns, weights.unsqueeze(2))
        return weighted.squeeze(2)

In [None]:
def cross_attention(img_latent, mlp_latent, module=None):
    """Fuse the image and tabular latents with one shared ``CrossAttention``.

    A single attention module is created lazily on first use and cached on
    this function. Building a fresh, randomly initialised module per call
    would project every batch (and train vs. evaluation data) through
    different random weights.

    Args:
        img_latent: Image latent rows, ``(rows, latent_dim)``.
        mlp_latent: Tabular latent rows, ``(rows, latent_dim)``.
        module: Optional callable ``module(img_latent, mlp_latent)`` to use
            instead of the cached default, e.g. an attention module whose
            parameters are registered with an optimizer.

    Returns:
        Whatever the attention module returns for the two latents.

    NOTE(review): the cached module's parameters are not added to any
    optimizer here, so they keep their initial values unless the caller
    registers ``cross_attention._shared_module.parameters()`` itself.
    """
    if module is None:
        module = getattr(cross_attention, "_shared_module", None)
        if module is None:
            module = CrossAttention(
                img_feat_dim=latent_dim,
                word_feat_dim=latent_dim,
                attn_dim=latent_dim,
            ).to(DEVICE)
            cross_attention._shared_module = module

    return module(img_latent, mlp_latent)

In [None]:
def make_fake(real_latent):
    """Build "fake" latents by perturbing the real ones with matched noise.

    Per-feature mean and standard deviation are taken over the rows of
    ``real_latent``. Gaussian noise with those statistics is drawn, centred
    on its overall mean, rescaled by ``latent_std / noise.std()`` and added
    to ``real_latent``.

    Args:
        real_latent: Tensor of shape ``(rows, features)``. At least two rows
            are needed, otherwise ``torch.std`` yields NaN.

    Returns:
        Tensor with the same shape as ``real_latent``. It is not detached, so
        gradients can flow back into ``real_latent``.
    """
    latent_mean = torch.mean(real_latent, dim=0)
    latent_std = torch.std(real_latent, dim=0)

    # Standard normal -> N(mean, std^2): scale by the std, shift by the mean.
    # With the two swapped, zero-mean features yield constant noise and the
    # rescaling below divides by a zero standard deviation.
    noise = torch.randn_like(real_latent) * latent_std + latent_mean
    fake_latent = real_latent + (noise - noise.mean()) * (latent_std / noise.std())

    return fake_latent

In [None]:
def concat_data(transformer, image_folder_path, excel_file_path):
    """Pair every image with its spreadsheet row and load the result into memory.

    Images are discovered with ``ImageFolder``. An image is kept when its
    file name without extension equals a value in the spreadsheet's
    ``"Index"`` column; if a value appears more than once, the first row wins.

    Args:
        transformer: Transform applied to each image by ``ImageFolder``.
        image_folder_path: Root directory handed to ``ImageFolder``.
        excel_file_path: Spreadsheet with ``"Index"``, ``"air_quarter"`` and
            ``"pta"`` columns.

    Returns:
        List of ``(image_tensor, pta_data, filename)`` tuples in
        ``ImageFolder`` order, where ``pta_data`` is
        ``torch.tensor([air_quarter, pta])`` and ``filename`` is the matching
        ``"Index"`` value.
    """
    image_dataset = ImageFolder(image_folder_path, transform=transformer)
    excel_dataset = pd.read_excel(excel_file_path)

    index_column = excel_dataset["Index"]
    quarter_column = excel_dataset["air_quarter"]
    pta_column = excel_dataset["pta"]

    # One pass over the sheet gives O(1) row lookup instead of list.index().
    row_by_name = {}
    for row_position, name in enumerate(index_column.tolist()):
        row_by_name.setdefault(name, row_position)

    samples = []
    for image_index, (path, _) in enumerate(image_dataset.imgs):
        # splitext handles extensions of any length (".png", ".jpeg", ...).
        stem = os.path.splitext(os.path.basename(path))[0]
        excel_index = row_by_name.get(stem)
        if excel_index is None:
            continue

        # Index by position so images sharing a base name in different class
        # folders each load their own file.
        image_tensor = image_dataset[image_index][0]

        pta_data = torch.tensor([quarter_column.iloc[excel_index],
                                 pta_column.iloc[excel_index]])
        filename = index_column.iloc[excel_index]

        samples.append((image_tensor, pta_data, filename))

    return samples

In [None]:
def save_e(model, epoch, train_acc, test_acc):
    """Write the encoder ``model`` to disk with ``torch.save``.

    The file name embeds ``epoch``, ``train_acc`` and ``test_acc``; the path
    starts at the filesystem root.
    """
    name_template = "/encoder_{}_{}_{}.pt"
    torch.save(model, name_template.format(epoch, train_acc, test_acc))
    
def save_d(model, epoch, train_acc, test_acc):
    """Write the discriminator ``model`` to disk with ``torch.save``.

    The file name embeds ``epoch``, ``train_acc`` and ``test_acc``; the path
    starts at the filesystem root.
    """
    name_template = "/discriminator_{}_{}_{}.pt"
    torch.save(model, name_template.format(epoch, train_acc, test_acc))

In [None]:
# Resize every image to a square of side `resize`, then convert it to a tensor.
transformer = transforms.Compose(
    [
        transforms.Resize((resize, resize)),
        transforms.ToTensor(),
    ]
)

# Image folder and the spreadsheet used for the train / test split.
train_image_path = '/home/data/'
train_excel_path = '/home/C-PTA_final_ver3.xlsx'

datasets = concat_data(transformer, train_image_path, train_excel_path)

# Ordered 80/20 split: with shuffle=False the last fifth becomes the test set.
train_set, test_set = train_test_split(datasets, test_size=0.2, shuffle=False)
print(len(train_set))
print(len(test_set))

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=64,
)
test_loader = DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=64,
)

In [None]:
# Second spreadsheet, matched against the same image folder as above; its
# samples are fed to test() as the abnormal loader.
test_excel_path = '/home/I-PTA_final_ver3.xlsx'
test_abnormal = concat_data(transformer, train_image_path, test_excel_path)
print(len(test_abnormal))

abnormal_dataloader = DataLoader(
    test_abnormal,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=64,
    drop_last=False,
)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Image branch: latent rows of width `latent_dim`.
img_encoder = Img_Encoder(latent_dim=latent_dim).to(DEVICE)

# Tabular branch: its output width is (resize/16)^2 * latent_dim values per
# sample, which MLP_Encoder regroups into rows of width `latent_dim`.
output_vector = int((resize*(1/16))**2)*latent_dim
mlp_encoder = MLP_Encoder(
    input_size=2,
    hidden_size=resize,
    output_size=output_vector,
).to(DEVICE)

dual_encoder = Dual_Encoder(img_encoder, mlp_encoder).to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# One AdamW optimizer per parameter group; `dual_encoder` wraps both branches,
# so its optimizer covers the same parameters as the two per-branch ones.
optimizer_dual_encoder = optim.AdamW(dual_encoder.parameters(), lr=lr)
optimizer_mlp_encoder = optim.AdamW(mlp_encoder.parameters(), lr=lr)
optimizer_img_encoder = optim.AdamW(img_encoder.parameters(), lr=lr)
optimizer_discriminator = optim.AdamW(discriminator.parameters(), lr=lr)

In [None]:
def _train_sample_scores(pred, num_samples):
    """Average the discriminator outputs that belong to each sample of a batch."""
    rows_per_sample = pred.size(0) // num_samples
    flat = pred.detach().reshape(-1)
    return [chunk.mean().item() for chunk in flat.split(rows_per_sample)]


def _train_percentage(count, total):
    """Return ``count / total`` in percent, or 0.0 when nothing was counted."""
    return count / total * 100 if total else 0.0


def train(epoch, model, discriminator, train_loader):
    """Run one adversarial training epoch.

    For every batch the two encoders produce latents that are fused by
    ``cross_attention`` ("real") and perturbed by ``make_fake`` ("fake").
    The discriminator is updated once to tell them apart, then the encoders
    are updated so that the fake latents are scored as real.

    Relies on the notebook-level ``DEVICE``, ``criterion``,
    ``optimizer_discriminator`` and ``optimizer_dual_encoder``.

    Args:
        epoch: Current epoch number (not used inside the function).
        model: Module exposing ``encoder1`` (images) and ``encoder2`` (tabular).
        discriminator: Module mapping latent rows to probabilities.
        train_loader: Iterable of ``(images, pta, filename)`` batches.

    Returns:
        Tuple ``(train_acc, real_acc, fake_acc, real_error, fake_error,
        return_list)``. The first five are percentages over per-sample mean
        scores, where a real sample counts as correct when its score is
        ``> 0.5`` and a fake one when it is ``<= 0.5``. ``return_list`` holds
        one ``{filename: score}`` dict per misclassified real sample of the
        whole epoch. All percentages are 0.0 for an empty loader.
    """
    model.train()
    discriminator.train()

    real_total = real_hits = real_misses = 0
    fake_total = fake_hits = fake_misses = 0
    return_list = []

    for images, pta, filename in train_loader:

        image_data = images.to(DEVICE)
        # Cast in place of torch.tensor(tensor), which copies and warns.
        mlp_data = pta.to(DEVICE, dtype=torch.float32)

        img_latent_variable = model.encoder1(image_data)
        mlp_latent_variable = model.encoder2(mlp_data)

        real_latent = cross_attention(img_latent_variable, mlp_latent_variable)
        fake_latent = make_fake(real_latent)

        # ---- Discriminator update -------------------------------------
        # Latents are detached: this step must not touch the encoders, and
        # the encoder graph stays intact for the update further down.
        optimizer_discriminator.zero_grad()
        real_pred = discriminator(real_latent.detach())
        fake_pred = discriminator(fake_latent.detach())
        loss_real = criterion(real_pred, torch.ones_like(real_pred))
        loss_fake = criterion(fake_pred, torch.zeros_like(fake_pred))
        final_loss = (loss_real + loss_fake) / 2
        final_loss.backward()
        optimizer_discriminator.step()

        # ---- Per-sample bookkeeping (scores from before the update) ---
        real_scores = _train_sample_scores(real_pred, len(filename))
        fake_scores = _train_sample_scores(fake_pred, len(filename))

        for name, score in zip(filename, real_scores):
            real_total += 1
            if score > 0.5:
                real_hits += 1
            elif score <= 0.5:
                real_misses += 1
                return_list.append({name: score})

        for _, score in zip(filename, fake_scores):
            fake_total += 1
            if score <= 0.5:
                fake_hits += 1
            elif score > 0.5:
                fake_misses += 1

        # ---- Encoder update --------------------------------------------
        # The loss has to stay attached to the graph: back-propagating a
        # detached copy would leave the encoders without any gradient.
        optimizer_dual_encoder.zero_grad()
        fooled_pred = discriminator(fake_latent)
        encoder_loss = criterion(fooled_pred, torch.ones_like(fooled_pred))
        encoder_loss.backward()
        optimizer_dual_encoder.step()

    return (
        _train_percentage(real_hits + fake_hits, real_total + fake_total),
        _train_percentage(real_hits, real_total),
        _train_percentage(fake_hits, fake_total),
        _train_percentage(real_misses, real_total),
        _train_percentage(fake_misses, fake_total),
        return_list,
    )

In [None]:
def _evaluate_loader(model, discriminator, loader, row_label):
    """Score every batch of ``loader`` and collect per-sample statistics.

    Returns ``(total, above, below, low_scores, labels, preds)`` where
    ``above`` / ``below`` count samples whose mean score is ``> 0.5`` /
    ``<= 0.5``, ``low_scores`` maps the file names of the latter to their
    score, and ``labels`` / ``preds`` hold one entry per discriminator row.
    """
    total = above = below = 0
    low_scores = {}
    labels = []
    preds = []

    for images, pta, filename in loader:
        image_data = images.to(DEVICE)
        # Cast in place of torch.tensor(tensor), which copies and warns.
        mlp_data = pta.to(DEVICE, dtype=torch.float32)

        img_latent_variable = model.encoder1(image_data)
        mlp_latent_variable = model.encoder2(mlp_data)
        latent = cross_attention(img_latent_variable, mlp_latent_variable)

        pred = discriminator(latent)
        flat_pred = pred.reshape(-1)

        # Each sample owns a run of consecutive rows; average them per sample.
        rows_per_sample = pred.size(0) // len(filename)
        scores = [chunk.mean().item() for chunk in flat_pred.split(rows_per_sample)]

        for name, score in zip(filename, scores):
            # Count every sample, including one whose score is exactly 0.0.
            total += 1
            if score > 0.5:
                above += 1
            elif score <= 0.5:
                below += 1
                low_scores[name] = score

        labels.extend([row_label] * pred.size(0))
        preds.extend((flat_pred > 0.5).int().tolist())

    return total, above, below, low_scores, labels, preds


def _eval_percentage(count, total):
    """Return ``count / total`` in percent, or 0.0 when nothing was counted."""
    return count / total * 100 if total else 0.0


def test(epoch, model, discriminator, normal_loader, abnormal_loader):
    """Evaluate the discriminator on normal and abnormal samples.

    A normal sample counts as correct when its mean score is ``> 0.5``, an
    abnormal one when it is ``<= 0.5``. Gradients are disabled and both
    modules are switched to eval mode (and left in it).

    Relies on the notebook-level ``DEVICE`` and ``cross_attention``.

    Args:
        epoch: Current epoch number (not used inside the function).
        model: Module exposing ``encoder1`` (images) and ``encoder2`` (tabular).
        discriminator: Module mapping latent rows to probabilities.
        normal_loader: Iterable of ``(images, pta, filename)`` normal batches.
        abnormal_loader: Same structure, abnormal batches.

    Returns:
        Tuple ``(test_acc, normal_acc, abnormal_acc, normal_error,
        abnormal_error, return_normal, return_abnormal, labels, preds)``.
        The first five are percentages (0.0 when the loader in question is
        empty). ``return_normal`` / ``return_abnormal`` map file names to
        scores for samples scored ``<= 0.5``. ``labels`` (1 = normal,
        0 = abnormal) and ``preds`` (thresholded at 0.5) contain one entry
        per discriminator output row, normal rows first.
    """
    with torch.no_grad():

        model.eval()
        discriminator.eval()

        normal, normal_acc, normal_error, return_normal, labels, preds = _evaluate_loader(
            model, discriminator, normal_loader, 1
        )
        (abnormal, abnormal_error, abnormal_acc, return_abnormal,
         abnormal_labels, abnormal_preds) = _evaluate_loader(
            model, discriminator, abnormal_loader, 0
        )

        labels.extend(abnormal_labels)
        preds.extend(abnormal_preds)

        test_acc = normal_acc + abnormal_acc

        return (
            _eval_percentage(test_acc, normal + abnormal),
            _eval_percentage(normal_acc, normal),
            _eval_percentage(abnormal_acc, abnormal),
            _eval_percentage(normal_error, normal),
            _eval_percentage(abnormal_error, abnormal),
            return_normal,
            return_abnormal,
            labels,
            preds,
        )

In [None]:
# Per-epoch histories, one entry appended per epoch.
f_train_acc, f_test_acc = [], []

real_acc_list, fake_acc_list = [], []
normal_acc_list, abnormal_acc_list = [], []

real_error_list, fake_error_list = [], []
normal_error_list, abnormal_error_list = [], []

for epoch in tqdm(range(1, EPOCHS+1)):

    # One training pass, then an evaluation on the held-out normal set and
    # the abnormal set.
    (train_acc, real_acc, fake_acc,
     real_error, fake_error, real_list) = train(epoch, dual_encoder, discriminator, train_loader)
    (test_acc, normal_acc, abnormal_acc,
     normal_error, abnormal_error,
     normal_list, abnormal_list,
     label, pred) = test(epoch, dual_encoder, discriminator, test_loader, abnormal_dataloader)

    print("Epoch {} | Train acc: {: .2f}% | Test acc: {: .2f}% ".format(epoch, train_acc, test_acc))
    print("Real acc: {: .2f}% | Fake acc: {: .2f}% | Normal acc: {: .2f}% | Abnormal acc: {: .2f}%".format(real_acc, fake_acc, normal_acc, abnormal_acc))
    print("Real error: {: .2f}% | Fake error: {: .2f}% | Normal error: {: .2f}% | Abnormal error: {: .2f}%".format(real_error, fake_error, normal_error, abnormal_error))
    print('--------------------------------------------------------------------------------------------------')

    # Overall accuracies.
    f_train_acc.append(train_acc)
    f_test_acc.append(test_acc)

    # Training-side breakdown (real vs. fake latents).
    real_acc_list.append(real_acc)
    real_error_list.append(real_error)
    fake_acc_list.append(fake_acc)
    fake_error_list.append(fake_error)

    # Evaluation-side breakdown (normal vs. abnormal samples).
    normal_acc_list.append(normal_acc)
    normal_error_list.append(normal_error)
    abnormal_acc_list.append(abnormal_acc)
    abnormal_error_list.append(abnormal_error)
