In [1]:
import sys
import os

import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

from utils.options import args_parser
from utils.utils import exp_details, get_model
from utils.label_distribution_dataset import LabelDistributionDataset

ModuleNotFoundError: No module named 'sklearn'

In [None]:
# Start from the project's default options, then override the few this study changes.
args = args_parser(default=True)

option_overrides = {
    "local_ep": 5,
    "lr": 0.00001,
    "supervision": True,
}
for option_name, option_value in option_overrides.items():
    setattr(args, option_name, option_value)

# Print the resulting experiment configuration.
exp_details(args)

In [None]:

# Build the base model via the project's get_model helper (on CPU) and pickle the whole
# module to disk; its path is handed to both LabelDistributionDataset instances below.
base_model = get_model(args.arch, args.dataset, "cpu")
# NOTE(review): this uses args.data_path while the dataset cell passes args.data_dir —
# confirm the two options are meant to differ.
base_model_path = os.path.join(args.data_path, f"base_model_{args.dataset}.pth")
# torch.save does not create missing parent directories; fall back to "." if data_path is "".
os.makedirs(os.path.dirname(base_model_path) or ".", exist_ok=True)
torch.save(base_model, base_model_path)

In [None]:
def _label_distribution_split(is_train):
    """Build a LabelDistributionDataset from the shared options; only the split flag varies."""
    return LabelDistributionDataset(
        args.dataset,
        args.local_ep,
        args.local_bs,
        args.lr,
        args.optimizer,
        args.supervision,
        is_train,
        args.data_dir,
        base_model_path,
        "cuda",
    )


# Train split first, then test (same construction order as before).
train_dataset = _label_distribution_split(True)
test_dataset = _label_distribution_split(False)

In [None]:
print(f"Number of samples in train set: {len(train_dataset)}")
print(f"Number of samples in test set: {len(test_dataset)}")

# Both loaders use the same batching/shuffling settings.
loader_options = {"batch_size": 32, "shuffle": True}
train_dataloader = DataLoader(train_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

# Peek at one training batch to sanity-check tensor shapes.
for X, y in train_dataloader:
    print(f"Shape of X: {X.shape}")
    print(f"Shape of y: {y.shape}")
    break

In [None]:
class DistributionPredictor(nn.Module):
    """MLP that maps a feature vector to a probability distribution over outputs.

    Architecture: ``Linear(in_features, hidden_dim) -> ReLU``, then ``num_hidden_layers``
    blocks of ``Linear(hidden_dim, hidden_dim) -> ReLU``, then
    ``Linear(hidden_dim, out_features)`` followed by a softmax over dim 1, so each output
    row sums to 1.

    The defaults reproduce the original fixed architecture (10 -> 8 x 1000 -> 10). Submodule
    attribute names (and hence ``state_dict`` keys) and parameter-initialisation order are
    unchanged.

    Args:
        in_features: Size of each input vector.
        out_features: Length of the predicted distribution.
        hidden_dim: Width of every hidden layer.
        num_hidden_layers: Number of hidden ``Linear + ReLU`` blocks after the input layer.
    """

    def __init__(self, in_features=10, out_features=10, hidden_dim=1000, num_hidden_layers=8):
        super().__init__()
        self.input_layer = nn.Linear(in_features, hidden_dim)
        self.hidden_layers = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
            for _ in range(num_hidden_layers)
        )
        self.output_layer = nn.Linear(hidden_dim, out_features)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-row probabilities; assumes ``x`` is (batch, in_features)."""
        x = self.relu(self.input_layer(x))
        for layer in self.hidden_layers:
            x = layer(x)
        return self.softmax(self.output_layer(x))

In [None]:
def train(dataloader, model, loss_fn, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that supports ``len()``.
        model: ``nn.Module`` being optimised; it is switched to train mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimiser over ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Mean of the per-batch losses as a Python float (0.0 for an empty dataloader).
    """
    num_batches = len(dataloader)
    model.train()
    if num_batches == 0:
        return 0.0

    train_loss = 0.0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a detached Python float. Summing the tensor itself would chain every
        # batch's loss into one growing autograd graph and return a (possibly CUDA)
        # grad-tracking tensor that matplotlib/numpy cannot convert.
        train_loss += loss.item()
    return train_loss / num_batches


def test(dataloader, model, loss_fn, device):
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches
    return test_loss

In [None]:
# Train the distribution predictor, recording per-epoch train/test loss, then pickle the
# whole model so the inference cell can load it back.
device = "cuda"

model_save_path = "../save/distribution_predictor_model.pth"

model = DistributionPredictor()
model.to(device)
print(model)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

train_losses = []
test_losses = []

epochs = 500
for t in tqdm(range(epochs)):
    # Train first, then evaluate: test_losses[t] reflects the weights after epoch t. As a
    # result, the "final" test loss printed below describes the same model that gets saved.
    train_loss = train(train_dataloader, model, loss_fn, optimizer, device)
    # float() also detaches/moves to host if train() hands back a tensor.
    train_losses.append(float(train_loss))
    test_loss = test(test_dataloader, model, loss_fn, device)
    test_losses.append(float(test_loss))

print("Done!")
print(f"Final test loss: {test_losses[-1]:.5f}")
print(f"Final train loss: {train_losses[-1]:.5f}")

# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(model, model_save_path)

In [None]:
# Plot per-epoch losses. Entries are coerced to Python floats, so this also works if a loss
# was recorded as a (CUDA / grad-tracking) tensor, which matplotlib cannot convert to numpy.
train_curve = [float(value) for value in train_losses]
test_curve = [float(value) for value in test_losses]

plt.plot(train_curve, label="Train Loss")
plt.plot(test_curve, label="Test Loss")

plt.xlabel("Training Epoch")
plt.ylabel("Loss")
plt.title("Loss vs. Training Epoch")
plt.legend()
plt.show()


In [None]:
# Print predicted vs. ground-truth distributions (2 decimals) for one test batch. X, y and
# pred from that batch remain in scope for the comparison plot below.
np.set_printoptions(formatter={'float': lambda x: "{0:0.2f}".format(x)})

# The checkpoint is a whole pickled nn.Module (torch.save(model) in the training cell).
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects such files, so
# opt out explicitly. Unpickling can execute arbitrary code: only load checkpoints this
# notebook produced.
model = torch.load(model_save_path, map_location=device, weights_only=False)
model.eval()

with torch.no_grad():
    for X, y in test_dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        for a, b in zip(pred.cpu().numpy(), y.cpu().numpy()):
            print(a, b)
        break
        

In [None]:
# Side-by-side bar charts: predicted distribution (left) vs. ground truth (right) for the
# first samples of the batch captured by the inference cell (pred, y).
pred_np = pred.cpu().numpy()
y_np = y.cpu().numpy()

# Never request more rows than the batch actually holds (a short batch would IndexError).
num_plots = min(15, len(pred_np))
num_classes = pred_np.shape[1]

# squeeze=False keeps axs 2-D, so axs[i, col] indexing also works when num_plots == 1.
fig, axs = plt.subplots(num_plots, 2, figsize=(8, num_plots * 2), squeeze=False)

cmap = plt.get_cmap('tab10')
colors = [cmap(i % cmap.N) for i in range(num_classes)]

for i in range(num_plots):
    for col, values in enumerate((pred_np[i], y_np[i])):
        ax = axs[i, col]
        ax.bar(range(num_classes), values, color=colors)
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])

axs[0, 0].set_title("Prediction")
axs[0, 1].set_title("Ground Truth")

plt.savefig("Distribution Comparison CIFAR")