### 0. Install custom scikit-learn library and import other packages

In [4]:
! pip install ./scikit-learn-dual

Processing ./scikit-learn-dual
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: scikit-learn
  Building wheel for scikit-learn (pyproject.toml) ... [?25l[?25hdone
  Created wheel for scikit-learn: filename=scikit_learn-1.6.dev0-cp312-cp312-linux_x86_64.whl size=11863096 sha256=a53f0382d85456592db66b1cec9a8cc45f4cdb42a11c7f28a6a43a84aa29251f
  Stored in directory: /root/.cache/pip/wheels/a7/7e/30/512a381188b2ddd7a2d201bb5b768157895c82c7dae9fadfbe
Successfully built scikit-learn
Installing collected packages: scikit-learn
  Attempting uninstall: scikit-learn
    Found existing installation: scikit-learn 1.6.1
    Uninstalling scikit-learn-1.6.1:
      Successfully uninstalled scikit-learn-1.6.1
[31mERROR: pip's dependency resolver does not currently take into account all 

In [6]:
import os
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import wget
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss, Conv2d, BatchNorm2d
from torch.optim import SGD, lr_scheduler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torchvision
import warnings
import sys
import os

warnings.filterwarnings('ignore')

# Add the DualXDA folder to your python path
sys.path.append(os.path.abspath("./src"))
from src.explainers import DualDA

In [None]:
DATA_DIR = Path("D:\SIWY\SIWY-25Z-Jarczewski-Rozej-Jasinski\data")

### 1. Model

In [37]:
# Resnet9
class Mul(torch.nn.Module):
    def __init__(self, weight):
        super(Mul, self).__init__()
        self.weight = weight
    def forward(self, x): return x * self.weight


class Flatten(torch.nn.Module):
    def forward(self, x): return x.view(x.size(0), -1)


class Residual(torch.nn.Module):
    def __init__(self, module):
        super(Residual, self).__init__()
        self.module = module
    def forward(self, x): return x + self.module(x)


def construct_rn9(num_classes=10):
    def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1):
        return torch.nn.Sequential(
                torch.nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size,
                            stride=stride, padding=padding, groups=groups, bias=False),
                torch.nn.BatchNorm2d(channels_out),
                torch.nn.ReLU(inplace=True)
        )
    model = torch.nn.Sequential(
        conv_bn(3, 64, kernel_size=3, stride=1, padding=1),
        conv_bn(64, 128, kernel_size=5, stride=2, padding=2),
        Residual(torch.nn.Sequential(conv_bn(128, 128), conv_bn(128, 128))),
        conv_bn(128, 256, kernel_size=3, stride=1, padding=1),
        torch.nn.MaxPool2d(2),
        Residual(torch.nn.Sequential(conv_bn(256, 256), conv_bn(256, 256))),
        conv_bn(256, 128, kernel_size=3, stride=1, padding=0),
        torch.nn.AdaptiveMaxPool2d((1, 1)),
        Flatten(),
        torch.nn.Linear(128, num_classes, bias=False),
        Mul(0.2)
    )
    return model

### 2. Dataset preparation

In [38]:
# Define the full transformation pipeline for both datasets
image_transform = transforms.Compose(
                        [transforms.Resize(256),
                         transforms.CenterCrop(224),
                         transforms.ToTensor(),
                         transforms.Normalize((0.4914, 0.4822, 0.4465),
                                              (0.2023, 0.1994, 0.201))])

train_dataset = ImageFolder(root='/content/task1/easy/train', transform=image_transform)
val_dataset = ImageFolder(root='/content/task1/easy/val', transform=image_transform)

# This dataset will automatically find 3 classes (cat, dog, bird)
# and assign integer labels (e.g., cat=0, dog=1, bird=2)

In [39]:
def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, augment=True):
    if augment:
        transforms = torchvision.transforms.Compose(
                        [torchvision.transforms.Resize(256),
                         torchvision.transforms.CenterCrop(224),
                         torchvision.transforms.RandomHorizontalFlip(),
                         torchvision.transforms.RandomAffine(0),
                         torchvision.transforms.ToTensor(),
                         torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                          (0.2023, 0.1994, 0.201))])
    else:
        transforms = torchvision.transforms.Compose([
                                  torchvision.transforms.Resize(256), torchvision.transforms.CenterCrop(224),
                         torchvision.transforms.ToTensor(),
                         torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                          (0.2023, 0.1994, 0.201))])

    is_train = (split == 'train')
    if is_train:
      dataset = ImageFolder(root='/content/task1/easy/train', transform=transforms)
    else:
      dataset = ImageFolder(root='/content/task1/easy/val', transform=transforms)


    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         shuffle=shuffle,
                                         batch_size=batch_size,
                                         num_workers=num_workers)

    return loader

### 3. Training

In [41]:
def train(model, loader, lr=0.4, epochs=24, momentum=0.9,
          weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0, model_id=0):

    opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    iters_per_epoch = len(loader)
    # Cyclic LR with single triangle
    lr_schedule = np.interp(np.arange((epochs+1) * iters_per_epoch),
                            [0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
                            [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
    scaler = GradScaler()
    loss_fn = CrossEntropyLoss(label_smoothing=label_smoothing)

    for ep in range(epochs):
        print(ep)
        for it, (ims, labs) in enumerate(loader):
            ims = ims.cuda()
            labs = labs.cuda()
            opt.zero_grad(set_to_none=True)
            with autocast():
                out = model(ims)
                loss = loss_fn(out, labs)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            scheduler.step()

    return model

In [42]:
loader_for_training = get_dataloader(batch_size=32, split='train', shuffle=True)

for i in tqdm(range(1), desc='Training models..'):
    model = construct_rn9().to(memory_format=torch.channels_last).cuda()
    model = train(model, loader_for_training, model_id=i)

Training models..:   0%|          | 0/1 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23


In [90]:
C = 0.001
device = "cuda"
cache_dir = "/content/cache_dir"
features_dir = "/content/features_dir"

# Define a wrapper class to provide the 'classifier' attribute
class ModelWrapper(torch.nn.Module):
    def __init__(self, original_model):
        super().__init__()
        self.original_model = original_model
        # The linear layer is the second to last element in the Sequential model
        # It's at index -2 as the last one is the Mul layer
        self.classifier = original_model[-2]
        self.features = original_model[:-2]

    def forward(self, x):
        return self.original_model(x)

# Wrap the existing model instance to expose the 'classifier' attribute
model_wrapped = ModelWrapper(model)

explainer = DualDA(
    model_wrapped,
    train_dataset, # Changed from 'train' function to 'train_dataset'
    device=device,
    dir=cache_dir,
    features_dir=features_dir,
    C=C)

explainer.train()

loader_for_testing = get_dataloader(batch_size=32, split='val', shuffle=True)
for (x, y) in loader_for_testing:
    x = x.to('cuda')
    preds = model(x).argmax(dim=-1)
    attributions = explainer.explain(x, preds)

    for i in range(len(attributions)):
      print(attributions.detach().cpu().numpy()[i].argsort()[:5])
      # plt.plot(attributions.detach().cpu().numpy()[i])
      # plt.show()

Training explainer...
[367 280 282 296 265]
[367 280 282 296 247]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 247]
[367 280 282 265 247]
[367 280 282 265 296]
[282 280 367 265 247]
[367 280 282 296 247]
[367 280 282 265 296]
[367 280 282 296 265]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 247]
[367 280 282 265 296]
[367 280 282 265 247]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 296 265]
[367 280 282 247 265]
[367 282 280 265 296]
[367 280 282 265 296]
[367 280 282 296 265]
[367 280 282 265 296]
[367 280 282 265 247]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[282 280 367 265 296]
[367 280 282 265 247]
[367 282 280 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[282 367 280 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 282 265 296]
[367 280 2

IndexError: index 3 is out of bounds for axis 0 with size 3