In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import pandas as pd

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transforms for dataset

transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load dataset
train_dataset = ImageFolder(root='/home/iai/Desktop/Jeewon/Study/Conference/Active_Learning/data/mvtec2/train', transform=transform)
val_dataset = ImageFolder(root='/home/iai/Desktop/Jeewon/Study/Conference/Active_Learning/data/mvtec2/val', transform=transform)
test_dataset = ImageFolder(root='/home/iai/Desktop/Jeewon/Study/Conference/Active_Learning/data/mvtec2/test', transform=transform)

# Class mapping
class_map = {i: train_dataset.classes[i] for i in range(len(train_dataset.classes))}

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(len(train_dataset))
print(len(val_dataset))
print(len(test_dataset))

3211
1070
1070


In [3]:
import torch.nn as nn

# Define ResNet model
class ResNet18(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNet18, self).__init__()
        self.resnet18 = torchvision.models.resnet18(pretrained=True)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.resnet18.conv1(x)
        x = self.resnet18.bn1(x)
        x = self.resnet18.relu(x)
        x = self.resnet18.maxpool(x)

        x = self.resnet18.layer1(x)
        x = self.resnet18.layer2(x)
        x = self.resnet18.layer3(x)
        x = self.resnet18.layer4(x)

        x = self.resnet18.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

def train(model, optimizer, criterion, labeled_loader):
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    total = 0
    for images, labels in labeled_loader:
        images, labels = images.to(device), labels.to(device)  # Move data to GPU
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        train_acc += (predicted == labels).sum().item()
    train_loss /= len(labeled_loader)
    train_acc /= total
    return train_loss, train_acc

In [4]:
import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=4, verbose=False, delta=0, path='checkpoint.pt', min_epoch = 0, trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.epoch = 0
        self.min_epoch = min_epoch
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
    def __call__(self, val_loss, model):

        score = -val_loss
        self.epoch +=1
        
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            
            if self.epoch < self.min_epoch:
                self.counter = 0
                self.trace_func("Not enough epoch")
            else:
                self.counter += 1
                self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
                if self.counter >= self.patience:
                    self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [5]:
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, roc_auc_score, roc_curve, auc
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
import warnings
import pandas as pd
warnings.filterwarnings('ignore')

test_loss_rand = []
test_acc_rand = []
test_auc_score_rand = []
test_roc_auc_rand = []
test_precision_rand = []
test_recall_rand = []
test_f1_rand = []

df_rand2 = pd.DataFrame(columns = train_dataset.classes)

num_initial_samples = 100
batch_size = 128

# Define initial labeled dataset
labeled_indices_rand = torch.randperm(len(train_dataset))[:num_initial_samples]
labeled_dataset_rand = torch.utils.data.Subset(train_dataset, labeled_indices_rand)

unlabeled_indices_rand = torch.arange(len(train_dataset))[~torch.eq(torch.arange(len(train_dataset)).unsqueeze(1), labeled_indices_rand).any(1)]
unlabeled_dataset_rand = torch.utils.data.Subset(train_dataset, unlabeled_indices_rand)

# Define data loaders
labeled_loader_rand = DataLoader(labeled_dataset_rand, batch_size=128, shuffle=True)
unlabeled_loader_rand = DataLoader(unlabeled_dataset_rand, batch_size=128, shuffle=True)

val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)


step = 1
# Active learning loop
while(len(unlabeled_dataset_rand) > 0):
    print("Step number:", step)

    model_rand = ResNet18(num_classes=2).to(device)
    early_stopping_rand = EarlyStopping(patience = 20, verbose = True, min_epoch = 0)
    numSample_list = [1258,4096]
    weights = [1 - (x / sum(numSample_list)) for x in numSample_list]
    weights = torch.FloatTensor(weights).to(device)
    criterion = nn.CrossEntropyLoss(weights).to(device)
    optimizer = torch.optim.Adam(model_rand.parameters(), lr=0.0001)

    for epoch in range(200):
        # Train model on labeled dataset
        train_loss, train_acc = train(model_rand, optimizer, criterion, labeled_loader_rand)
        print(f"Step : {step}, Epoch : {epoch+1} - Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")    

    # Evaluate model on val dataset
        model_rand.eval()
        with torch.no_grad():
            val_loss = 0.0
            val_acc = 0.0
            total = 0
            y_true = []
            y_pred = []
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model_rand(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                val_acc += (predicted == labels).sum().item()
                y_true.extend(labels.tolist())
                y_pred.extend(predicted.tolist())
            val_loss /= len(val_loader)
            val_acc /= total

            precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average = 'macro')
            # Compute AUC score and ROC curve
            fpr, tpr, thresholds = roc_curve(y_true, y_pred)
            auc_score = roc_auc_score(y_true, y_pred)
            roc_auc = auc(fpr, tpr)
            
            # Early Stopping Condition
            early_stopping_rand(val_loss, model_rand)
            
            if early_stopping_rand.early_stop:
                print("Early stopping")
                print(f"Val Precision: {precision.item():.4f}, Val Recall: {recall.item():.4f}, Val F1 Score: {f1.item():.4f}")
                break
            
            print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")            
            print(f"AUC Score: {auc_score:.4f}")
            print("")
            
        model_rand.load_state_dict(torch.load('checkpoint.pt'))

    
    # Evaluate model on Test Dataset
    model_rand.load_state_dict(torch.load('checkpoint.pt'))
    model_rand.eval()
    with torch.no_grad():
        test_loss = 0.0
        test_acc = 0.0
        total = 0
        y_true = []
        y_pred = []
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_rand(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            test_acc += (predicted == labels).sum().item()
            y_true.extend(labels.tolist())
            y_pred.extend(predicted.tolist())
        test_loss /= len(test_loader)
        test_acc /= total
        
        # Compute AUC score and ROC curve
        fpr, tpr, thresholds = roc_curve(y_true, y_pred)
        auc_score_rand = roc_auc_score(y_true, y_pred)
        roc_auc_rand = auc(fpr, tpr)

        precision_rand, recall_rand, f1_rand, _ = precision_recall_fscore_support(y_true, y_pred, average = 'binary')
        test_loss_rand.append(test_loss)
        test_acc_rand.append(test_acc)
        test_auc_score_rand.append(auc_score_rand)
        test_roc_auc_rand.append(roc_auc_rand)
        test_precision_rand.append(precision_rand)
        test_recall_rand.append(recall_rand)
        test_f1_rand.append(f1_rand)

        
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")
        print(f"AUC Score: {auc_score_rand:.4f}")
    

    # Iteration end condition
    if len(labeled_dataset_rand) > int(len(train_dataset)/2):
        break
    
    # Make model predictions on unlabeled dataset
    model_rand.eval()
    predictions = []
    with torch.no_grad():
        for images, _ in unlabeled_loader_rand:
            images = images.to(device)
            output = model_rand(images)
            predictions.append(output)
    predictions = torch.cat(predictions, dim=0)

    # Select samples to label using random selection method
    idx = torch.randperm(len(unlabeled_dataset_rand))[:20]


    # Count selected classes
    class_counts = {class_name: 0 for class_name in train_dataset.classes}

    for i in idx:
        class_index = torch.argmax(predictions[i]).item()
        class_name = train_dataset.classes[class_index]
        class_counts[class_name] += 1

    new_row = {class_name: class_counts[class_name] for class_name in train_dataset.classes}
    df_rand2 = df_rand2.append(new_row, ignore_index=True)

    # print("data frame update")
    # print(df_rand2)



    labeled_indices_rand = torch.cat([labeled_indices_rand, unlabeled_indices_rand[idx]])
    unlabeled_indices_rand = unlabeled_indices_rand[~torch.eq(unlabeled_indices_rand.unsqueeze(1), unlabeled_indices_rand[idx]).any(1)]

    labeled_dataset_rand = torch.utils.data.Subset(train_dataset, labeled_indices_rand)
    unlabeled_dataset_rand = torch.utils.data.Subset(train_dataset, unlabeled_indices_rand)
    print("Updated Length of labeled dataset : ",len(labeled_dataset_rand))
    print("Length of unlabeled dataset : ",len(unlabeled_dataset_rand))
    print("-----------------------------------------------------------")

    # Update labeled and unlabeled data loaders
    labeled_loader_rand = DataLoader(labeled_dataset_rand, batch_size=batch_size, shuffle=True)
    unlabeled_loader_rand = DataLoader(unlabeled_dataset_rand, batch_size=batch_size, shuffle=True)

    step+=1

Step number: 1
Step : 1, Epoch : 1 - Train Loss: 0.7665, Train Accuracy: 0.6500
Validation loss decreased (inf --> 0.635699).  Saving model ...
Val Loss: 0.6357, Val Accuracy: 0.6411
AUC Score: 0.5266

Step : 1, Epoch : 2 - Train Loss: 0.4308, Train Accuracy: 0.8300
Validation loss decreased (0.635699 --> 0.615950).  Saving model ...
Val Loss: 0.6160, Val Accuracy: 0.6486
AUC Score: 0.5328

Step : 1, Epoch : 3 - Train Loss: 0.3089, Train Accuracy: 0.9400
Validation loss decreased (0.615950 --> 0.598472).  Saving model ...
Val Loss: 0.5985, Val Accuracy: 0.6860
AUC Score: 0.5393

Step : 1, Epoch : 4 - Train Loss: 0.2306, Train Accuracy: 0.9400
Validation loss decreased (0.598472 --> 0.587048).  Saving model ...
Val Loss: 0.5870, Val Accuracy: 0.7037
AUC Score: 0.5467

Step : 1, Epoch : 5 - Train Loss: 0.1660, Train Accuracy: 0.9800
Validation loss decreased (0.587048 --> 0.581009).  Saving model ...
Val Loss: 0.5810, Val Accuracy: 0.6972
AUC Score: 0.5356

Step : 1, Epoch : 6 - Train Lo

In [None]:
print(test_loss_rand)
print(test_acc_rand)
print(test_auc_score_rand)
print(test_roc_auc_rand)
print(test_precision_rand)
print(test_recall_rand)
print(test_f1_rand)

[2.9751770430141025, 3.327148523595598, 1.2495696577760909, 1.543341663148668, 1.375522806826565, 1.0428146322568257, 0.94167387433764, 1.0305057614265631, 0.685347760717074, 3.827482803020151, 0.8077805936336517, 0.6210329023500284, 1.142158208621873, 0.758786403056648, 0.931375707189242, 0.8257470097806718, 0.8271802928712633, 0.6830103016561933, 0.9773798783620199, 0.7345955570538839, 0.6768564374910461, 0.9018977714909447, 0.6088952223459879, 0.7106225879655944, 0.7856138498969408, 0.6099977774752511, 0.7491737041208479, 1.2267359139190779, 0.6947953237427605, 0.5611807559099462, 0.5941186795632044, 0.5335849894003736, 0.6079886878530184, 0.6935367923643854, 0.48401052090856767, 0.6110255006286833, 0.5426083447204696, 0.4439196570052041, 0.6422177826364835, 0.4861729145050049, 0.6269412967893813, 0.6242279741499159, 0.4888133505980174, 0.5351292077038023, 0.5227813969055811, 0.48077916353940964, 0.5202283023132218, 0.5137048901783096, 0.4656278147465653, 0.4933744050148461, 0.43091