In [1]:
import sys
sys.path.append("..")

from models.modeling import VisionTransformer, CONFIGS
from data_utils.cifarn_dataset import CIFAR10
from models_vgg import vgg11_bn
from models_dense import densenet121
from models_resnet import resnet50
import numpy as np
import torch
import gc
from datetime import datetime
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models, transforms
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from torchvision.models.feature_extraction import create_feature_extractor
from torchsummary import summary
import torchmetrics.functional.classification as M
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, classification_report
# Prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    _device_name = "cuda:0"
elif torch.backends.mps.is_available():
    _device_name = 'mps'
else:
    _device_name = "cpu"
device = torch.device(_device_name)

In [None]:
k = 2  # set K: the number of clusters

# Which cluster's m_theta is trained in this run (1-based). Each cluster's
# training set uses its own group of noisy-label augmentations, following how
# the training sets were created in identifying_annot_profiles.ipynb.
cluster_id = 2

# Augmentation groups per cluster for K=2; extend this mapping when K changes.
AUGS_BY_CLUSTER = {
    1: ['a1', 'a2', 'a3'],
    2: ['a4', 'a5', 'a6'],
}
augs = AUGS_BY_CLUSTER[cluster_id]

In [2]:
# helpers
def evaluate(model, dataloader, device, show=True, criterion=None, num_classes=10):
    """Run `model` over `dataloader` in eval mode and score its predictions against the true labels.

    Each batch must yield ``(imgs, labels, n_labels)``, where ``n_labels`` are integer
    noisy labels. They are one-hot encoded and passed to the model together with
    the images. The model must return a 3-tuple whose last element holds the class
    logits.

    Args:
        model: module called as ``model(imgs, one_hot_noisy_labels)``.
        dataloader: iterable of ``(imgs, labels, n_labels)`` batches.
        device: device the batches are moved to.
        show: if True, print a classification report, plot the confusion matrix
            and return ``(ground_truth, predictions)`` as integer numpy arrays.
            Otherwise return ``(accuracy, mean_loss)``.
        criterion: optional callable ``criterion(outputs, noisy_labels, labels)``.
            When given, ``mean_loss`` is its per-sample mean over the whole loader;
            otherwise ``mean_loss`` is None.
        num_classes: width of the one-hot encoding of the noisy labels.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()
    pred_batches, truth_batches = [], []
    loss_sum, n_samples = 0.0, 0
    with torch.no_grad():
        for imgs, labels, n_labels in dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            n_labels = n_labels.to(device)
            n_labels_onehot = F.one_hot(n_labels, num_classes=num_classes).to(dtype=torch.float32)

            _, _, outputs = model(imgs, n_labels_onehot)
            if criterion is not None:
                # Weight by batch size so the result is a per-sample mean even
                # when the last batch is smaller than the others.
                loss_sum += criterion(outputs, n_labels, labels).item() * labels.size(0)
                n_samples += labels.size(0)

            # Collect per-batch results and concatenate once (avoids quadratic re-copying).
            pred_batches.append(outputs.argmax(dim=1).cpu())
            truth_batches.append(labels.cpu())

    if not truth_batches:
        raise ValueError("evaluate: dataloader yielded no batches")

    predictions = torch.cat(pred_batches).numpy()
    groundTruth = torch.cat(truth_batches).numpy()

    if show:
        print(classification_report(y_true=groundTruth, y_pred=predictions))
        ConfusionMatrixDisplay(confusion_matrix(groundTruth, predictions)).plot()
        return groundTruth, predictions

    accuracy = float((predictions == groundTruth).mean())
    mean_loss = loss_sum / n_samples if criterion is not None else None
    return accuracy, mean_loss
    
def getClusterNoiseMatrix(trainloader, num_classes=10):
    """Estimate a cluster's label-noise matrix from a loader of ``(imgs, labels, n_labels)`` batches.

    Returns the row-normalised confusion matrix between the true labels and the
    noisy labels: entry ``[i, j]`` is the fraction of samples of true class ``i``
    whose noisy label is ``j``.

    The matrix is always ``num_classes x num_classes``, even if some class never
    occurs, so it stays compatible with a ``num_classes``-way softmax. Rows of
    classes with no samples are all zeros (sklearn fills them with 0).

    Raises:
        ValueError: if the loader yields no batches.
    """
    truth_batches, noisy_batches = [], []
    # Only the labels are needed, so they stay on the CPU. The loader still loads
    # and transforms the images.
    for _, labels, n_labels in trainloader:
        truth_batches.append(labels.cpu())
        noisy_batches.append(n_labels.cpu())

    if not truth_batches:
        raise ValueError("getClusterNoiseMatrix: trainloader yielded no batches")

    groundTruth = torch.cat(truth_batches).numpy()
    noisyLabels = torch.cat(noisy_batches).numpy()

    # labels= fixes the matrix size and class order; normalize='true' divides each row by its count.
    return confusion_matrix(groundTruth, noisyLabels, labels=np.arange(num_classes), normalize='true')

class CorrectionLoss(nn.Module):
    """Supervised loss plus a forward noise-correction penalty on the noisy label.

    ``total = loss1(prediction, target) + C * mean_k NLL(log(softmax(prediction) @ N_k), n_input)``

    Here ``N_k`` ranges over the noise matrices that were supplied (``N_human``
    and/or ``N_base``). Each ``N_k`` is expected to be a
    ``(num_classes, num_classes)`` tensor whose row ``i`` is the distribution of
    observed (noisy) labels given true class ``i``, e.g. the row-normalised
    confusion matrix from ``getClusterNoiseMatrix``.

    Args:
        loss1: criterion applied to ``(prediction, target)``, e.g. ``nn.CrossEntropyLoss()``.
        C: weight of the correction term.
        N_human: optional noise matrix of the annotator cluster.
        N_base: optional noise matrix of the base model.
        eps: lower bound applied before the log, so a zero probability gives a
            large finite penalty instead of inf/NaN.
    """

    def __init__(self, loss1, C=0, N_human=None, N_base=None, eps=1e-12):
        super().__init__()
        self.loss1 = loss1
        self.C = C
        self.N_h = N_human
        self.N_b = N_base
        self.eps = eps

    def noiseCorrection(self, prediction, n_input):
        """Return ``C`` times the mean forward-corrected NLL of the integer noisy labels ``n_input``.

        Returns a zero scalar (same dtype/device as ``prediction``) when no noise
        matrix was given, instead of the NaN produced by averaging nothing.
        """
        matrices = [N for N in (self.N_h, self.N_b) if N is not None]
        if not matrices:
            return prediction.new_zeros(())

        softmax_pred = F.softmax(prediction, dim=1)
        # torch.stack keeps every penalty in the autograd graph. Building a new
        # tensor from them (torch.tensor([...])) would detach the correction term,
        # so it would never produce gradients.
        penalties = torch.stack([
            F.nll_loss(torch.log(torch.matmul(softmax_pred, N).clamp_min(self.eps)), n_input, reduction='mean')
            for N in matrices
        ])
        return self.C * penalties.mean()

    def forward(self, prediction, n_input, target):
        """Return ``loss1(prediction, target)`` plus the noise-correction penalty on ``n_input``."""
        return self.loss1(prediction, target) + self.noiseCorrection(prediction, n_input)
    
def saveModel(model, modelName, directory='adapt_models'):
    """Save ``model.state_dict()`` to ``<directory>/<modelName>_<YYYYmmdd_HHMMSS>`` and return that path.

    The directory is created if it is missing, because ``torch.save`` does not
    create parent directories. The file has no extension, matching the checkpoint
    names this notebook loads back with ``torch.load``.
    """
    import os  # local import keeps the notebook's import cell unchanged

    os.makedirs(directory, exist_ok=True)
    model_path = os.path.join(directory, '{}_{}'.format(modelName, datetime.now().strftime('%Y%m%d_%H%M%S')))
    torch.save(model.state_dict(), model_path)
    return model_path

In [3]:
# Release cached GPU memory and stale objects from earlier runs, then seed torch.
torch.cuda.empty_cache()
gc.collect()
torch.random.manual_seed(0)


def _normalized(*steps):
    """Compose `steps`, then ToTensor and a per-channel (x - 0.5) / 0.5 normalisation."""
    return transforms.Compose([
        *steps,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


# Training images get random crops + CIFAR-10 AutoAugment; test images are only resized.
transform_train = _normalized(
    transforms.RandomResizedCrop((224, 224), scale=(0.05, 1.0)),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
)
transform_test = _normalized(transforms.Resize((224, 224)))

# Both splits carry the clustered noisy labels selected by `k` and `augs`.
_loader_kwargs = dict(batch_size=512, num_workers=1)

trainset = CIFAR10(root="./data/", k=k, augs=augs, train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)

testset = CIFAR10(root="./data/", k=k, augs=augs, train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

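Any model architecture of your choice can be used as the base model; assign it to `self.base_model` below.
ViT (ViT-B/16) is used in this sample.

The base model's output should be a logit vector whose size equals the number of classes. The ViT's output is indexed with `[0]` in `forward`; adjust that line for other architectures.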

In [4]:
class AdaptedAI(nn.Module):
    """m_theta: combines a frozen base image classifier with the annotator's noisy label.

    The base model's class logits and an encoding of the one-hot noisy label are
    concatenated and fed to a small MLP (``decision_ai``), which outputs the final
    class logits.

    Args:
        num_classes: number of classes. This is the size of the base logits, of
            the one-hot noisy label and of the output. Defaults to 10.
        base_model: optional pre-built base model. It must return a sequence whose
            first element is the ``(batch, num_classes)`` logits, as the ViT used
            here does. When omitted, a ViT-B/16 for 224x224 inputs is built.
    """

    def __init__(self, num_classes=10, base_model=None):
        super().__init__()
        if base_model is None:
            base_model = VisionTransformer(CONFIGS['ViT-B_16'], 224, zero_head=True, num_classes=num_classes)
            # loading the trained base model with consensus labels (at the end of the crowdlab process)
            # base_model.load_state_dict(torch.load("path/to/base_model.bin", map_location=torch.device(device)))
        self.base_model = base_model
        self.base_model.to(device)
        # Freeze the whole base model; selected layers can be unfrozen after construction.
        self.base_model.requires_grad_(False)

        # encodes the one-hot noisy label
        self.n_l_encoder = nn.Sequential(
            nn.Linear(num_classes, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes)
        )

        # input -> base-model logits concatenated with the encoded noisy label
        self.decision_ai = nn.Sequential(
            nn.Linear(2 * num_classes, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, imgs, n_l):
        """Return ``(base-model logits, the raw noisy-label input n_l, final logits)``."""
        img_features = self.base_model(imgs)[0]  # the ViT returns a sequence; [0] holds the logits
        n_l_features = self.n_l_encoder(n_l)

        out = torch.cat((img_features, n_l_features), dim=1)
        out = self.decision_ai(out)
        return img_features, n_l, out

In [5]:
adapt_model = AdaptedAI().to(device)

# Freeze every base-model parameter (same effect as setting requires_grad=False one by one),
# so this cell alone defines which base-model layers train.
adapt_model.base_model.requires_grad_(False)

# loading the trained m_theta if needed
# adapt_model.load_state_dict(torch.load('./adapt_models/k1_vitb16_l.1_20230817_181808', map_location=torch.device(device)))

# Unfreeze the base model's classification head; add other layers here as needed.
adapt_model.base_model.head.requires_grad_(True)

Linear(in_features=768, out_features=10, bias=True)

In [6]:
# Cluster noise matrix: row i is the distribution of this cluster's noisy labels for true class i.
noise_matrix = getClusterNoiseMatrix(trainloader)
noise_H = torch.as_tensor(noise_matrix, dtype=torch.float32, device=device)

# Optionally, the base model's noise matrix can be used as well (CorrectionLoss's N_base).
# Fill in a (10, 10) row-normalised matrix below to enable it; None disables it.
# An empty array would become a 0-element tensor that cannot be multiplied with the predictions.
noise_B = None
# noise_B = torch.as_tensor(np.array([...]), dtype=torch.float32, device=device)

In [7]:
# loss = CE(outputs, true label) + 0.1 * forward-corrected NLL of the noisy label (cluster matrix only)
criterion = CorrectionLoss(loss1=nn.CrossEntropyLoss(), C=0.1, N_human=noise_H)
# Default Adam hyper-parameters. Frozen parameters never get gradients, so Adam leaves them unchanged.
optimizer = optim.Adam(params=adapt_model.parameters())

In [8]:
# Only checkpoints whose test accuracy exceeds this value are saved.
best_acc = 0.90
# TensorBoard step counter (one step per training batch). It is an int because steps are integral.
global_step = 0
writer = SummaryWriter('./adapt_logs/training')

In [None]:
torch.cuda.empty_cache()
gc.collect()

EPOCHS = 1
for e in range(EPOCHS):
    # evaluate() (and any earlier evaluation cell) leaves the model in eval mode.
    # Switch back every epoch so layers like dropout behave as in training.
    adapt_model.train()
    for i, (imgs, labels, n_labels) in enumerate(trainloader):
        imgs = imgs.to(device)
        labels = labels.to(device)
        n_labels = n_labels.to(device)
        n_labels_onehot = F.one_hot(n_labels, num_classes=10).to(dtype=torch.float32)

        optimizer.zero_grad()
        _, _, outputs = adapt_model(imgs, n_labels_onehot)

        # The noisy labels go in as class indices; the true labels are the CE target.
        loss = criterion(outputs, n_labels, labels)
        loss.backward()
        optimizer.step()

        # per-batch statistics (not running averages)
        batch_loss = loss.item()
        preds = outputs.argmax(dim=1)
        batch_acc = (preds == labels).sum().item() / len(labels)

        print("epoch:{}, batch:{}, loss:{}, acc:{}".format(e+1, i+1, batch_loss, batch_acc))
        global_step += 1
        writer.add_scalar("Train/Loss", batch_loss, global_step)
        writer.add_scalar("Train/Acc", batch_acc, global_step)

    test_acc, test_loss = evaluate(adapt_model, testloader, device, False, criterion)
    writer.add_scalar("Test/Loss", test_loss, global_step)
    writer.add_scalar("Test/Acc", test_acc, global_step)
    if test_acc > best_acc:
        print("Accuracy", test_acc)
        best_acc = test_acc
        saveModel(adapt_model, "model")

writer.flush()

In [None]:
# No criterion is passed, so the result is (test accuracy, None).
evaluate(adapt_model, testloader, device, show=False)