In [None]:
import matplotlib
from matplotlib import pyplot as plt
import numpy as np
import torch
import math
import networkx as nx
from tqdm import tqdm
from PIL import Image
import json
import copy
from torch.utils.data import DataLoader, random_split

from expbasics.helper import get_attributions, get_model_etc, to_name
from expbasics.network import load_model, train_network, accuracy_per_class
from expbasics.biased_noisy_dataset import get_biased_loader, BiasedNoisyDataset

%reload_ext autoreload
%autoreload 2

In [None]:
# Sweep over dataset bias levels and model seeds: train a fresh network for every
# (bias, seed) pair and record its per-class test accuracy in `accs`.
BIASES = list(np.round(np.linspace(0, 1, 51), 3))  # 51 evenly spaced levels in [0, 1]; use 21 for a coarser sweep
SEEDS = [9, 5, 15, 14, 6, 1]

NAME = "../clustermodels/test"
SPLIT_SEED = 431  # base seed for the train/test split and loader shuffling
# Fractions of the dataset used for training / testing / left unused.
SPLIT_FRACTIONS = [0.01, 0.005, 0.985]

BATCH_SIZE = 8
LEARNING_RATE = 0.001
EPOCHS = 5
STRENGTH = 0.5
IMAGE_PATH = "../dsprites-dataset/images"

accs = []  # one [bias, seed, per_class_accuracy] record per run
for bias in tqdm(BIASES, desc="bias"):
    # The dataset arguments depend only on `bias`, so build it once per bias
    # instead of once per seed.
    # NOTE(review): assumes BiasedNoisyDataset construction is deterministic -- confirm.
    ds = BiasedNoisyDataset(bias, STRENGTH, img_path=IMAGE_PATH)
    for seed in SEEDS:
        # A fresh generator per run. Previously a single generator was shared by all
        # runs and also consumed by the shuffling loaders during training/evaluation,
        # so each run's split depended on every run before it (and on re-executing
        # the cell). Seeding per run makes each (bias, seed) result reproducible on
        # its own and, presumably (if len(ds) does not depend on bias), gives every
        # bias the same split for a given seed.
        rand_gen = torch.Generator().manual_seed(SPLIT_SEED + seed)
        trainds, testds, _ = random_split(ds, SPLIT_FRACTIONS, generator=rand_gen)
        train_loader = DataLoader(trainds, batch_size=BATCH_SIZE, shuffle=True, generator=rand_gen)
        test_loader = DataLoader(testds, batch_size=BATCH_SIZE, shuffle=True, generator=rand_gen)
        name = to_name(bias, seed)
        model = train_network(
            train_loader,
            bias,
            STRENGTH,
            NAME,
            BATCH_SIZE,
            load=False,
            retrain=False,
            learning_rate=LEARNING_RATE,
            epochs=EPOCHS,
            num_it=seed,
            seeded=True,
            disable=True,
        )
        acc = list(accuracy_per_class(model, test_loader))
        # tqdm.write keeps the progress bar intact instead of interleaving plain prints.
        tqdm.write(f"{name}: {acc}")
        accs.append([bias, seed, acc])