Imports

In [None]:
import numpy as np
import os
import sys
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# ---------------------------------------------------------------------------
# Hardware selection
# ---------------------------------------------------------------------------
GPU_INDEX = 4   # physical GPU exposed to this process through CUDA_VISIBLE_DEVICES
isGPU = True    # when True, the setup cell applies the GPU-related environment


# ---------------------------------------------------------------------------
# Experiment settings
# ---------------------------------------------------------------------------
NUM_EPOCHES = 50    # epochs per training run
NUM_CLASS = 100     # 10 -> CIFAR-10, 100 -> CIFAR-100
EPOCH_THRES = 5     # re-select the training subset every EPOCH_THRES epochs
MODEL_NAME = "resnet"
DATA_NAME = "cifar"
DATA_DIR = './data/' + DATA_NAME + '-' + str(NUM_CLASS)


# ---------------------------------------------------------------------------
# Threshold sweep
# ---------------------------------------------------------------------------
# One single-element threshold list per run: [0.1], [0.2], ..., [0.9].
thresholds = [[tenth / 10] for tenth in range(1, 10)]
# One keep-mask per run, with one flag per bin. A single threshold gives two
# bins (scores <= threshold, scores > threshold); only the upper bin is kept.
to_keep = [[False, True] for _ in thresholds]


# Alternative multi-threshold configurations (N thresholds -> N + 1 bins):
# thresholds = [[0.1, 0.9], [0.2, 0.8], [0.3, 0.7]]
# to_keep = [[False, True, True] , [False, True, True], [False, True, True]]

# thresholds = [[0.1, 0.3, 0.7], [0.2, 0.5, 0.8]]
# to_keep = [[False, True, True, False], [False, True, True, False]]



In [None]:
 # HF model
# HF_API_TOKEN = os.getenv("HF_API_TOKEN")


# The CUDA-related environment variables must be set BEFORE torch is imported,
# which is why the torch imports live in this cell rather than the import cell.
if isGPU:
    # Enumerate devices in PCI bus order so GPU_INDEX refers to a stable device.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # NOTE(review): machine-specific absolute cache path — adjust per host.
    os.environ["HF_HOME"] = "/data2/meithnav/.hfcache/"
    # Expose only the selected GPU; inside this process it then appears as device 0.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_INDEX)
    os.environ["WANDB_DISABLED"] = "true"


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
torch.manual_seed(42)

# Only touch the CUDA runtime when it actually exists: on an MPS- or CPU-only
# machine `torch.cuda.set_device` would raise even though the device selection
# below supports those backends.
if isGPU and torch.cuda.is_available():
    torch.cuda.set_device(0)  # index 0 == GPU_INDEX because of CUDA_VISIBLE_DEVICES
    print("\n\n--> CONNECTED TO GPU NO: ", torch.cuda.current_device())
    print("--> GPU_INDEX: ", GPU_INDEX)

    torch.cuda.empty_cache()  # clear GPU cache
    # `reset_max_memory_allocated` is deprecated; this is its replacement.
    torch.cuda.reset_peak_memory_stats()


# Device preference: MPS (Apple Silicon), then CUDA (Nvidia), then CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# APPEND ROOT DIRECTORY
# Adds the parent of the current working directory to the import path
# (equivalent to the previous `os.path.join(os.path.dirname('hangman'), '..')`,
# because `os.path.dirname('hangman')` is the empty string).
sys.path.append(os.path.abspath('..'))


print(os.getcwd())


# Output locations for saved models and figures; safe to re-run.
os.makedirs('./models', exist_ok=True)
os.makedirs('./outputs', exist_ok=True)


In [None]:
class ForgetabilityTracker:
    """Counts, per dataset index, how often a sample has been misclassified.

    The counts are stored in a NumPy ``int32`` array of length ``dataset_size``
    on the host; a higher count means the sample was predicted incorrectly in
    more ``update`` calls.
    """

    def __init__(self, dataset_size, device):
        """
        Args:
            dataset_size: Number of samples to track (length of the count array).
            device: Accepted for call-site compatibility; not used, since the
                counts are kept in NumPy rather than on a torch device.
        """
        self.misclassification_counts = np.zeros(dataset_size, dtype=np.int32)

    def update(self, predictions, labels, indices):
        """Adds one to the count of every sample whose prediction is wrong.

        Args:
            predictions: Tensor of predicted class ids, one per sample.
            labels: Tensor of ground-truth class ids, same shape as ``predictions``.
            indices: Tensor of dataset indices identifying each sample.
        """
        incorrect = (predictions != labels).cpu().numpy().astype(np.int32)
        sample_ids = indices.cpu().numpy()
        # `counts[sample_ids] += incorrect` is buffered: an index that appears
        # more than once in `sample_ids` would only be incremented once (last
        # write wins). `np.add.at` accumulates every occurrence.
        np.add.at(self.misclassification_counts, sample_ids, incorrect)

    def get_scores(self):
        """Returns the live count array (not a copy), indexed by dataset index."""
        return self.misclassification_counts

ResNet 18

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm stages and a skip path.

    The first convolution applies ``stride``; the second always uses stride 1.
    When the block changes the spatial size or the channel count, the skip
    path is a 1x1 convolution followed by batch-norm so both branches have
    matching shapes; otherwise the skip path passes the input through as is.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip branch: projection only when the shapes would otherwise differ.
        needs_projection = (stride != 1) or (in_channels != out_channels)
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()  # empty Sequential returns its input

    def forward(self, x):
        """Returns ``relu(main_branch(x) + shortcut(x))``."""
        residual = self.shortcut(x)
        hidden = self.conv1(x)
        hidden = torch.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        return torch.relu(hidden + residual)

class ResNet18(nn.Module):
    """ResNet-18 style classifier built from ``BasicBlock`` stages.

    The stem is a single 3x3, stride-1 convolution with batch-norm (no
    max-pooling layer follows it), then four stages of two blocks each with
    64, 128, 256 and 512 channels, global average pooling and a linear head.

    Args:
        num_classes: Size of the output layer. Defaults to the notebook-level
            ``NUM_CLASS``.
    """

    def __init__(self, num_classes=NUM_CLASS):
        super().__init__()
        # Running channel count consumed and updated by `_make_layer`.
        self.in_channels = 64

        # Initial Convolutional Layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Use the constructor argument: previously the head was always sized
        # by the global NUM_CLASS and `num_classes` was silently ignored.
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_channels, blocks, stride):
        """Builds one stage of ``blocks`` BasicBlocks.

        Only the first block uses ``stride``; the remaining ones use stride 1.
        Updates ``self.in_channels`` so the next stage is wired correctly.
        """
        layers = []
        for block_stride in [stride] + [1] * (blocks - 1):
            layers.append(BasicBlock(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Returns the class logits for a batch of 3-channel images."""
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = torch.flatten(out, 1)  # (batch, 512, 1, 1) -> (batch, 512)
        out = self.fc(out)
        return out


Training loop

In [None]:
class _IndexedDataset(torch.utils.data.Dataset):
    """Wraps a map-style dataset so every item also carries its dataset index."""

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        data, target = self.base_dataset[index]
        return data, target, int(index)


def _select_indices_by_bins(scores, thresholds, to_keep=None):
    """Returns the indices whose min-max normalised score lies in a retained bin.

    ``thresholds`` (N values) split the normalised range into N + 1 bins:
    ``<= t0``, ``(t0, t1]``, ..., ``> t_last``. ``to_keep`` holds one boolean
    per bin; ``None`` keeps every bin.
    """
    scores = np.asarray(scores, dtype=np.float64)
    low = scores.min() if scores.size else 0.0
    span = (scores.max() - low) if scores.size else 0.0
    # All scores equal -> nothing to rank; use zeros instead of dividing by zero.
    normalized = (scores - low) / span if span > 0 else np.zeros_like(scores)

    bins = [np.where(normalized <= thresholds[0])[0]]
    for lower, upper in zip(thresholds[:-1], thresholds[1:]):
        bins.append(np.where((normalized > lower) & (normalized <= upper))[0])
    bins.append(np.where(normalized > thresholds[-1])[0])

    if to_keep is not None:
        if len(to_keep) != len(bins):
            raise ValueError(
                f"Invalid `to_keep` length. Expected {len(bins)} booleans, but got {len(to_keep)}."
            )
        bins = [members for members, keep in zip(bins, to_keep) if keep]

    return np.concatenate(bins) if bins else np.array([], dtype=int)


def train_model(
    model,
    optimizer,
    criterion,
    train_loader,
    tracker=None,
    epoch_threshold=5,
    thresholds_arr=None,
    to_keep=None,
    num_epochs=None
):
    """
    Trains a model with dynamic dataset adjustment based on forgetability and thresholds.

    Without a tracker this is a plain training loop over ``train_loader``.
    With a tracker, the loader's dataset is wrapped so that every batch also
    yields the true dataset index of each sample; misclassifications are then
    credited to the right sample regardless of shuffling or earlier pruning.
    (Deriving indices from the batch position, as before, is only valid for an
    unshuffled loader over the full dataset.)

    Every ``epoch_threshold`` epochs the tracker's scores are min-max
    normalised, binned by ``thresholds_arr`` and the training set is replaced
    by the samples in the bins flagged in ``to_keep``. Scores are cumulative
    over the whole run and the bins are always computed over the full dataset,
    so samples that were dropped keep their last score.

    Reads the notebook-level ``device`` and, when ``num_epochs`` is None,
    ``NUM_EPOCHES``.

    Args:
        model: The neural network model.
        optimizer: Optimizer for training.
        criterion: Loss function.
        train_loader: DataLoader for the training dataset. When a tracker is
            given, only its dataset, batch size, worker count and whether it
            shuffles are reused; items must be ``(data, target)`` pairs.
        tracker: ForgetabilityTracker instance, optional.
        epoch_threshold: Number of epochs after which to update the dataset.
        thresholds_arr: Threshold value(s) in [0, 1]: a single float or a
            list/tuple of ascending floats.
        to_keep: List of booleans indicating which bins to retain
            (``len(thresholds) + 1`` entries); None keeps all bins.
        num_epochs: Number of epochs to train; defaults to ``NUM_EPOCHES``.

    Raises:
        ValueError: If ``to_keep`` does not have one entry per bin.
    """
    total_epochs = NUM_EPOCHES if num_epochs is None else num_epochs
    tracking = tracker is not None

    # Accept a bare float as well as a sequence of thresholds.
    thresholds = []
    if thresholds_arr is not None:
        thresholds = [float(value) for value in np.atleast_1d(thresholds_arr)]
    pruning = tracking and len(thresholds) > 0

    # Fail before training starts rather than after `epoch_threshold` epochs.
    if pruning and to_keep is not None and len(to_keep) != len(thresholds) + 1:
        raise ValueError(
            f"Invalid `to_keep` length. Expected {len(thresholds) + 1} booleans, but got {len(to_keep)}."
        )

    if tracking:
        indexed_dataset = _IndexedDataset(train_loader.dataset)
        loader_options = dict(
            batch_size=train_loader.batch_size,
            num_workers=train_loader.num_workers,
        )
        was_shuffled = isinstance(train_loader.sampler, torch.utils.data.RandomSampler)
        current_loader = DataLoader(indexed_dataset, shuffle=was_shuffled, **loader_options)
    else:
        current_loader = train_loader  # Use the caller's loader unchanged

    for epoch in range(total_epochs):
        model.train()

        for batch in tqdm(current_loader, desc=f"Epoch {epoch + 1}/{total_epochs}"):
            if tracking:
                data, targets, sample_indices = batch
            else:
                data, targets = batch
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            if tracking:
                _, predictions = torch.max(outputs, 1)
                tracker.update(predictions, targets, sample_indices)

        # Update dataset every `epoch_threshold` epochs if tracker is enabled
        if pruning and (epoch + 1) % epoch_threshold == 0:
            print(f"Epoch {epoch + 1}: Evaluating forgetability and updating dataset...")

            forgetability_scores = tracker.get_scores()
            indices_to_keep = _select_indices_by_bins(forgetability_scores, thresholds, to_keep)
            print(f"Keeping {len(indices_to_keep)} out of {len(forgetability_scores)} datapoints.")

            if len(indices_to_keep) == 0:
                # An empty shuffled loader cannot be built (and would train on
                # nothing), so keep training on the current selection.
                print("No datapoints selected; keeping the current training set.")
                continue

            # Subset over the indexed wrapper still yields ORIGINAL dataset indices.
            current_loader = DataLoader(
                Subset(indexed_dataset, indices_to_keep.tolist()),
                shuffle=True,
                **loader_options
            )


Evaluation

In [None]:
def evaluate_model(model, test_loader):
    """Computes top-1 accuracy of ``model`` over every batch of ``test_loader``.

    The model is switched to evaluation mode (and left there) and gradients
    are disabled for the pass. Batches are moved to the notebook-level
    ``device``.

    Returns:
        Fraction of correctly classified samples, as a float in [0, 1].
    """
    model.eval()
    num_correct, num_seen = 0, 0
    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            predicted = torch.max(logits, 1).indices  # arg-max class per sample
            num_correct += (predicted == batch_labels).sum().item()
            num_seen += batch_labels.size(0)
    return num_correct / num_seen

Dataset

In [None]:
# Convert images to tensors and map each channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Supported class counts and the torchvision dataset that provides them.
_CIFAR_BY_NUM_CLASS = {
    10: torchvision.datasets.CIFAR10,
    100: torchvision.datasets.CIFAR100,
}

# Fail here with a clear message instead of passing `None` datasets on to
# DataLoader, which would only surface as an unrelated error later.
if NUM_CLASS not in _CIFAR_BY_NUM_CLASS:
    raise ValueError(
        f"Unsupported NUM_CLASS={NUM_CLASS}; expected one of {sorted(_CIFAR_BY_NUM_CLASS)}."
    )

_cifar_cls = _CIFAR_BY_NUM_CLASS[NUM_CLASS]
train_dataset = _cifar_cls(root=DATA_DIR, train=True, download=True, transform=transform)
test_dataset = _cifar_cls(root=DATA_DIR, train=False, download=True, transform=transform)


# Loader used for the full-dataset baseline, and the shared evaluation loader.
train_loader_full = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [None]:
# Baseline run: a fresh ResNet-18 trained on the complete training set with no
# forgetability tracking, evaluated on the held-out test set for reference.

criterion = nn.CrossEntropyLoss()
model_baseline = ResNet18(num_classes=NUM_CLASS).to(device)
optimizer_baseline = optim.Adam(model_baseline.parameters(), lr=0.001)

print("\n\n****\nTraining baseline model...")
train_model(
    model=model_baseline,
    optimizer=optimizer_baseline,
    criterion=criterion,
    train_loader=train_loader_full,
)

# evaluate_model returns a fraction; it is printed as a percentage.
accuracy_baseline = evaluate_model(model_baseline, test_loader)
print(f"Baseline Model Accuracy: {accuracy_baseline * 100:.2f}%")

## Threshold sweep: train with forgetability-based pruning at each threshold

In [None]:
# Loader handed to the dynamic runs.
# NOTE(review): the baseline uses batch_size=128 while these runs use 64 —
# confirm the difference is intended before comparing accuracies.
train_loader_dynamic = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Each threshold configuration needs its own keep-mask; check this up front
# instead of hitting an IndexError part-way through the sweep.
if len(to_keep) != len(thresholds):
    raise ValueError(
        f"`to_keep` has {len(to_keep)} entries but `thresholds` has {len(thresholds)}."
    )

# Test accuracy of each run, in the same order as `thresholds`.
accuracies = []
for idx, thres in enumerate(tqdm(thresholds, desc="Threshold sweep")):
    print(f"\n\n****\n-> RUNNING THRES : {thres}, MODEL: {MODEL_NAME}, DATASET: {DATA_NAME}-{NUM_CLASS}")

    # Fresh model, optimizer and tracker for every configuration.
    model_dynamic = ResNet18(num_classes=NUM_CLASS).to(device)
    optimizer_dynamic = optim.Adam(model_dynamic.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    tracker_dynamic = ForgetabilityTracker(len(train_dataset), device)

    train_model(model_dynamic, optimizer_dynamic, criterion, train_loader_dynamic, tracker_dynamic, EPOCH_THRES, thres, to_keep[idx])

    accuracy_dynamic = evaluate_model(model_dynamic, test_loader)
    accuracies.append(accuracy_dynamic)

    # Include every threshold in the file name so multi-threshold configurations
    # that share a first value do not overwrite each other. For a single
    # threshold the name is unchanged (e.g. "...-thres-0.1.pt").
    thres_tag = '-'.join(map(str, thres))
    # The whole module object is saved (not just its state_dict), so loading
    # requires the ResNet18 / BasicBlock class definitions to be importable.
    torch.save(model_dynamic, f'./models/{MODEL_NAME}-{DATA_NAME}-{NUM_CLASS}-thres-{thres_tag}.pt')
    print(f"-> TEST ACC : {accuracy_dynamic}, THRES : {thres}")


In [None]:
# Accuracy of each dynamic run plotted against its threshold configuration.
fig, ax = plt.subplots()
ax.plot(thresholds, accuracies)
ax.set_xlabel('Threshold')
ax.set_ylabel('Accuracy')
ax.set_title('Accuracy vs Threshold')

# Build the file-name tag outside the f-string: reusing the same quote
# character inside an f-string is a SyntaxError before Python 3.12.
strategy_tag = '-'.join(map(str, to_keep[0]))

# Save BEFORE showing: once `plt.show()` returns, the figure may already be
# closed, and saving afterwards typically writes a blank image.
fig.savefig(f'./outputs/{MODEL_NAME}-{DATA_NAME}-{NUM_CLASS}-strategy-{strategy_tag}.png')
plt.show()