In [1]:
import os
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline  

import warnings
warnings.filterwarnings("ignore")

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.dataset import Dataset


In [3]:
train_transform = transforms.Compose([
            transforms.RandomCrop(224, padding=24),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
            transforms.ToTensor(),
        ])


test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

test_datasset = datasets.ImageFolder(root='./data/rsna_pneumonia/test/', transform=test_transform)
val_datasset = datasets.ImageFolder(root='./data/rsna_pneumonia/val/', transform=test_transform)
train_datasset = datasets.ImageFolder(root='./data/rsna_pneumonia/train/', transform=train_transform)

In [4]:
train_loader = torch.utils.data.DataLoader(train_datasset, batch_size=128,
                                          shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_datasset, batch_size=128,
                                          shuffle=False, num_workers=8)
val_loader = torch.utils.data.DataLoader(val_datasset, batch_size=128,
                                          shuffle=False, num_workers=8)

In [5]:
def train_classifer_epoch(net, trainloader,
                optimizer, device):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    loss_func = nn.CrossEntropyLoss()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs.to(device))
        loss = loss_func(outputs, targets)
        loss.backward()
        optimizer.step()
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        train_loss += loss.item()
    train_loss = train_loss/(batch_idx+1)
    train_acc = correct/total
    return train_acc, train_loss

def test_classifer_accuracy(net, testloader, device):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    loss_func = nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            targets = targets.to(device)
            outputs = net(inputs.to(device))
            loss = loss_func(outputs, targets)
            test_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    test_acc = correct/total
    test_loss = test_loss/(batch_idx+1)
    return test_acc, test_loss

def get_preds_labels(net, testloader, device):
    net.eval()
    preds = []
    true_labels = []
    sigmoid = nn.Sigmoid()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            outputs = net(inputs.to(device))
            predicted = torch.softmax(outputs, dim=1).detach()
            preds.append(predicted)
            true_labels.append(targets)
        preds = torch.cat(preds)
        true_labels = torch.cat(true_labels)
    return preds.cpu().numpy(), true_labels.cpu().numpy()

In [6]:
def train(net, epochs, trainloader, validloader,
          seed, save_path, device):
    best_acc = -np.inf
    test_acc_history = []
    optimizer = optim.AdamW(net.parameters(),
                          lr=5e-5,)#     weight_decay=5e-4
    for epoch in range(epochs):
        train_acc, train_loss  = train_classifer_epoch(net, trainloader, optimizer,
                                                      device)
        test_acc, test_loss = test_classifer_accuracy(net, validloader, device)
        test_acc_history.append(test_acc)
        print(f'epoch ({epoch+1})| Train loss {round(train_loss, 2)}| Train accuracy {round(train_acc, 2)}| Test accuracy {round(test_acc, 2)}| Test loss {round(test_loss, 2)}')
        if best_acc < test_acc:
            print('Saving model...')
            model_state = {'net': net.state_dict(),
                           'opti': optimizer.state_dict(),
                           'epoch': epoch,
                           'seed': seed,
                           'acc': test_acc,'epoch': epoch,
                           'test_acc_history':test_acc_history}
            torch.save(model_state, save_path)
            best_acc = test_acc

In [7]:
# torchvision.models.list_models()

In [8]:

def modify_model_output_classes(model, num_classes):
    if hasattr(model, 'fc'):
        # Common case for models like ResNet
        in_features = model.fc.in_features
        new_classifier = nn.Linear(in_features, num_classes).to(device)
        model.fc = new_classifier
        return model
    
    elif hasattr(model, 'classifier'):
        if isinstance(model.classifier, nn.Linear):
            # Common case for models like densenet, vgg
            in_features = model.classifier.in_features
            new_classifier = nn.Linear(in_features, num_classes).to(device)
            model.classifier = new_classifier
            return model
        
        if isinstance(model.classifier, nn.Sequential) and isinstance(model.classifier[1], nn.Linear):
            # Common case for models like efficientnet_b2
            in_features = model.classifier[1].in_features
            new_classifier = nn.Linear(in_features, num_classes).to(device)
            model.classifier[1] = new_classifier
            return model
    elif hasattr(model, 'heads'):
        # ViT Special case for models with 'heads' as the final layer
        in_features = model.heads.head.in_features
        new_classifier = nn.Linear(in_features, num_classes).to(device)
        model.heads.head = new_classifier
        return model
    elif hasattr(model, 'head'):
        # Swin
        in_features = model.head.in_features
        new_classifier = nn.Linear(in_features, num_classes).to(device)
        model.head = new_classifier
        return model
    else:
        raise ValueError("Unsupported model architecture. Cannot modify output classes.")


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

seeds = list(range(16, 20))
# 'resnet18' , 
# 'vit_b_16'
# 'swin_s', - one is missing
# 'efficientnet_b2',
seed_counter = 0
for model_type in [  'densenet121']:
    for i in range(4):
        seed = seeds[seed_counter]
        torch.random.manual_seed(seed)
        model = torch.hub.load("pytorch/vision", model_type,
                               weights="IMAGENET1K_V1").to(device)
        model = modify_model_output_classes(model, 3)
        train(model, 100, train_loader, test_loader, seed,
              f'./models/xray/{model_type}_{i}.ckpt', device)
        seed_counter+=1
    break


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


epoch (1)| Train loss 0.71| Train accuracy 0.67| Test accuracy 0.7| Test loss 0.65
Saving model...
epoch (2)| Train loss 0.61| Train accuracy 0.72| Test accuracy 0.72| Test loss 0.62
Saving model...
epoch (3)| Train loss 0.58| Train accuracy 0.74| Test accuracy 0.72| Test loss 0.61
Saving model...
epoch (4)| Train loss 0.55| Train accuracy 0.75| Test accuracy 0.72| Test loss 0.59
Saving model...
epoch (5)| Train loss 0.52| Train accuracy 0.77| Test accuracy 0.74| Test loss 0.58
Saving model...
epoch (6)| Train loss 0.49| Train accuracy 0.78| Test accuracy 0.75| Test loss 0.58
Saving model...
epoch (7)| Train loss 0.47| Train accuracy 0.8| Test accuracy 0.76| Test loss 0.55
Saving model...
epoch (8)| Train loss 0.43| Train accuracy 0.82| Test accuracy 0.75| Test loss 0.61
epoch (9)| Train loss 0.4| Train accuracy 0.83| Test accuracy 0.75| Test loss 0.59
epoch (10)| Train loss 0.36| Train accuracy 0.85| Test accuracy 0.75| Test loss 0.64
epoch (11)| Train loss 0.33| Train accuracy 0.86| 

In [None]:
# for model_type in ['resnet18' ,  'vit_b_16','swin_s', 'efficientnet_b2', 'densenet121']: 
#     model = torch.hub.load("pytorch/vision", model_type,
#                                weights="IMAGENET1K_V1").to(device)
#     model = modify_model_output_classes(model, 10)

In [None]:
net.load_state_dict(torch.load('./models/x_ray_victim.ckpt')['net'])

In [None]:
preds, labels = get_preds_labels(net, test_loader, device)

In [None]:
from sklearn.metrics import roc_auc_score, roc_curve, auc

def roc_curve_ood(preds, labels):
    fpr, tpr, thresholds = roc_curve(labels, preds)
    thres_idx = np.argmin(np.square(tpr-0.95))
    thres95 = thresholds[thres_idx]
    auc_score = roc_auc_score(labels, preds)
    return thres95, auc_score


thres95, auc = roc_curve_ood(preds, labels)
print(f'AUROC: {auc}')

### Adversarial attack(un-adaptive)

In [None]:
def GradientPGDAttack(model, inputs, labels,
                                   loss_func, eps, device):
    pertubation = torch.zeros(inputs.size()).to(device)
    samples = inputs.to(device)
    for i in range(50):
        pertubation = torch.autograd.Variable(pertubation, requires_grad=True)
        classifer_outputs = model((samples+pertubation))
        classifier_labels = labels.to(device)
        loss = loss_func(classifer_outputs, classifier_labels)
        loss.backward()
        gradient = torch.ge(pertubation.grad.data, 0)
        pertubation = torch.add(pertubation.data, +(eps/10)*gradient).detach()    
        pertubation = torch.clamp(pertubation, min = -eps, max= eps)
        del classifer_outputs, classifier_labels, gradient, loss
    return (samples+pertubation)

def gen_adv_dataset(model, loader, eps):
    data = []
    labels = []
    for samples, targets in loader:
        labels.append(targets)
        samples, targets = samples.to(device), targets.to(device)

        pertubed = GradientPGDAttack(model, samples, targets,
                                  BinaryCrossEntropy(),
                                  eps, device).detach().cpu()
        data.append(pertubed)
    data = torch.cat(data)
    labels = torch.cat(labels)
    dataset = torch.utils.data.TensorDataset(data, labels)
    return torch.utils.data.DataLoader(dataset, batch_size=128,
                      shuffle=False,num_workers =6,pin_memory = True)

In [None]:
adv_dataloader = gen_adv_dataset(net, test_loader, 8/255)

In [None]:
preds, labels = get_preds_labels(net, adv_dataloader, device)
thres95, auc = roc_curve_ood(preds, labels)
print(f'AUROC: {auc}')

# MGM

In [None]:
from sklearn.metrics import roc_auc_score
from sklearn import metrics

In [None]:
def net_pen_rep(net, x):
    out = net.conv1(x)
    out = net.bn1(out)
    out = net.relu(out)
    out = net.layer1(out)
    out = net.layer2(out)
    out = net.layer3(out)
    out = net.layer4(out)
    out = net.avgpool(out)
    return out

def get_pen_reps(net, loader, device):
    net.eval()
    reps = []
    with torch.no_grad():
        for (inputs, targets) in loader:
            inputs= inputs.to(device)
            input_reps = net_pen_rep(net, inputs)
            input_reps = input_reps.detach().cpu()
            reps.append(input_reps)
        reps = torch.cat(reps).squeeze()
        return reps    
    
def calc_MGM_params(net, trainloader):
    train_reps = get_pen_reps(net,
                              trainloader,
                              device).t()
    m = train_reps.size(1) 
    f_dim = train_reps.size(0)
    mu = train_reps.mean(dim = 1,  keepdim=True)
    train_reps -= mu
    cov = (1/ (m+1)) * train_reps.matmul(train_reps.t()) +1e-10*torch.eye(f_dim, f_dim)
    R = torch.cholesky(cov, upper=False)
    R_diag_sum = R.diag().sum()
    R_inv = torch.cholesky_inverse(R)
    return f_dim, R_diag_sum, R_inv, mu
    
def calc_likelihood(rep, f_dim, R_diag_sum, R_inv, mu):
    Z = -0.5*R_diag_sum -f_dim*np.log(2*np.pi)
    log_exp = (R_inv.matmul(rep - mu)**2).sum(dim=0)
    return Z+log_exp
    
def predict_liklihood(net, loader,f_dim, R_diag_sum, R_inv, mu):
    net.eval()
    preds = []
    with torch.no_grad():
        for (inputs, targets) in loader:
            inputs= inputs.to(device)
            reps = net_pen_rep(net, inputs)
            reps = reps.detach().cpu()
            pred = calc_likelihood(reps.squeeze().t(), f_dim, 
                                   R_diag_sum, R_inv, mu)
            preds.append(pred)
        preds = torch.cat(preds)
    return preds

In [None]:
f_dim, R_diag_sum, R_inv, mu = calc_MGM_params(net, train_loader)

In [None]:
adv_likelihood = predict_liklihood(net, adv_dataloader, f_dim, R_diag_sum, R_inv, mu)
bengin_likelihood = predict_liklihood(net, test_loader, f_dim, R_diag_sum, R_inv, mu)

In [None]:
plt.hist(adv_likelihood.numpy(), bins=100, color='r')
plt.hist(bengin_likelihood.numpy(), bins=100, color='b')
plt.show()