In [37]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler


from model import AlexNet, LeNet


In [41]:
## Test learning to reweight model
# from main import Reweighting


# torch.backends.cudnn.enabled = False
# torch.manual_seed(1)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device = ", device)

class Reweighting():
    def __init__(self, network, hyperparameters, criterion, criterion_mean, optimizer, train_loader, valid_loader, test_loader):
        self.network = network.requires_grad_(requires_grad=True)
        self.hyperparameters = hyperparameters
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.criterion_mean = criterion_mean
        self.gradient_network = None

    def train(self):
        # Train the network
        for epoch in range(self.hyperparameters['n_epochs']):
            self.network.train()
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data = data.to(device)
                target = target.to(device)

                self.optimizer.zero_grad()
                output = self.network(data)
                # print(output)
                loss = self.criterion_mean(output, target)
                # print(loss)
                loss.backward()
                self.optimizer.step()
                if batch_idx % self.hyperparameters['log_interval'] == 0:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(data), len(self.train_loader.dataset),
                        100. * batch_idx / len(self.train_loader), loss.item()))


    def test(self):
        self.network.eval()

        acc = []
        for itr,(test_img, test_label) in enumerate(self.test_loader):
            prediction = self.network(test_img.to(device)).detach().cpu().numpy()
            # print(prediction)
            prediction = np.argmax(prediction, axis=1)
            tmp = (prediction == test_label.detach().numpy())
            print(prediction)
            print(test_label)
            acc.append(tmp)

        accuracy = np.concatenate(acc).mean()
        return np.round(accuracy*100,2)


device =  cuda:0


In [44]:
train_folder = './dataset/train_unbiased'
test_folder = './dataset/test_unbiased'
validate_folder = './dataset/validate_unbiased'

class_weights = [0.3, 0.1, 0.3, 0.3]  # Example weights for each class

transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[119.93560047938638, 121.99074304889741, 129.42976558005753], std=[65.61636024791385, 64.00107977894356, 60.628164585048744]),  # Normalize images
])

train_dataset = datasets.ImageFolder(train_folder, transform=transform)
weights = [class_weights[label] for label in train_dataset.targets]
train_sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

test_dataset = datasets.ImageFolder(test_folder, transform=transform)
weights = [class_weights[label] for label in test_dataset.targets]
test_sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

validate_dataset = datasets.ImageFolder(validate_folder, transform=transform)
weights = [class_weights[label] for label in validate_dataset.targets]
valid_sampler = WeightedRandomSampler(weights, len(weights), replacement=True)


# number of epoch and log interval reduced for testing
hyperparameters = {
    'n_epochs' : 50,
    'batch_size' : 50,
    'learning_rate' : 1e-3,
    'momentum' : 0.5,
    'log_interval' : 4
}

network = AlexNet()

criterion = nn.CrossEntropyLoss(reduction='none')
criterion_mean = nn.CrossEntropyLoss(reduction='mean')

optimizer = optim.SGD(network.params(),
                        lr=hyperparameters['learning_rate'],
                        momentum=hyperparameters['momentum'])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=hyperparameters['batch_size'], sampler=train_sampler)
test_loader = DataLoader(test_dataset, batch_size=hyperparameters['batch_size'], sampler=test_sampler)
valid_loader = DataLoader(validate_dataset, batch_size=hyperparameters['batch_size'], sampler=valid_sampler)

our_model = Reweighting(network, hyperparameters, criterion, criterion_mean, optimizer, train_loader, valid_loader, test_loader)

start_accuracy = our_model.test()
print("Starting accuracy = ", start_accuracy)

our_model.train()

end_accuracy = our_model.test()
print("Ending accuracy = ", end_accuracy)



[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
tensor([0, 3, 3, 1, 1, 1, 3, 1, 2, 1, 1, 1, 1, 0, 1, 3, 2, 0, 2, 2, 0, 3, 3, 1,
        0, 1, 2, 1, 2, 2, 3, 2, 1, 0, 1, 0, 3, 3, 2, 1, 1, 1, 1, 1, 3, 0, 1, 3,
        1, 1])
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
tensor([2, 1, 3, 3, 1, 2, 2, 3, 0, 0, 3, 1, 0, 1, 1, 1, 0, 3, 3, 0, 2, 3, 2, 0,
        1, 2, 0, 0, 2, 2, 1, 3, 1, 2, 3, 1, 3, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1,
        1, 0])
[0 0 0 0]
tensor([1, 0, 2, 1])
Starting accuracy =  21.15
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1]
tensor([0, 1, 1, 0, 0, 2, 1, 0, 3, 2, 1, 1, 3, 1, 3, 0, 0, 0, 0, 0, 1, 0, 2, 1,
        1, 3, 0, 2, 2, 3, 1, 3, 3, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 3, 1, 1, 0,
        1, 0])
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1]
tenso

In [15]:
## Test control model
# from main import Reweighting


# torch.backends.cudnn.enabled = False
# torch.manual_seed(1)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device = ", device)

class NoReweighting():
    def __init__(self, network, hyperparameters, criterion, optimizer, train_loader, test_loader):
        self.network = network
        self.hyperparameters = hyperparameters
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.train_loader
        self.test_loader = test_loader

    def train(self):
        # Train the network
        for epoch in range(hyperparameters['n_epochs']):
            self.network.train()
            for batch_idx, (data, target) in enumerate(train_loader):
                data = data.to(device)
                target = target.to(device)

                self.optimizer.zero_grad()
                output = self.network(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
                # if batch_idx % hyperparameters['log_interval'] == 0:
                #     print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                #         epoch, batch_idx * len(data), len(train_loader.dataset),
                #         100. * batch_idx / len(train_loader), loss.item()))


    def test(self):
        test_loss = 0
        correct = 0
        for batch_idx, (data, target) in enumerate(test_loader):
            self.network.eval()
            output = self.network(data.to(device)).cpu()
            test_loss += self.criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()

        # print('\nTest set: \nAvg. loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
        #     test_loss, correct, len(test_loader.dataset),
        #     100.0 * correct / len(test_loader.dataset)))
        
        return (100.0 * correct / len(test_loader.dataset)).item()

# number of epoch and log interval reduced for testing
hyperparameters = {
    'n_epochs' : 2,
    'batch_size_train' : 100,
    'batch_size_valid' : 10,
    'batch_size_test' : 1000,
    'learning_rate' : 1e-3,
    'momentum' : 0.5,
    'log_interval' : 1
}

network = LeNet()

criterion = nn.CrossEntropyLoss(reduction='none')
criterion_mean = nn.CrossEntropyLoss(reduction='mean')

optimizer = optim.SGD(network.params(),
                        lr=hyperparameters['learning_rate'],
                        momentum=hyperparameters['momentum'])
# Load the data
data_loader = MNISTDataLoader(validation_ratio=0.05,
                            batch_size_train=hyperparameters['batch_size_train'],
                            batch_size_valid=hyperparameters['batch_size_valid'],
                            batch_size_test=hyperparameters['batch_size_test'])

# 10% representation by 9's class in training data
desired_sample_distribution = [100, 100, 100, 100, 100, 100, 100, 100, 100, 10]
data_loader.sample_bias(desired_sample_distribution, dataset="train")


train_loader = data_loader.train_dataloader
valid_loader = data_loader.valid_dataloader
test_loader = data_loader.test_dataloader


our_model = NoReweighting(network, hyperparameters, criterion_mean, optimizer, train_loader, test_loader)

start_accuracy = our_model.test()
print("Starting accuracy = ", start_accuracy)

our_model.train()

end_accuracy = our_model.test()
print("Ending accuracy = ", end_accuracy)


device =  cuda:0
Starting accuracy =  9.569999694824219
Ending accuracy =  18.65999984741211
