In [9]:
import tqdm
import itertools
#PyTorch
import torch

In [7]:
import sys
sys.path.append("../src/")

%load_ext autoreload
%autoreload 2
# Importing our custom module(s)
import layers
import likelihoods
import losses
import priors
import utils

In [2]:
cifar10_dir = "/cluster/tufts/hugheslab/eharve06/CIFAR-10"
cifar101_v4_dir = "/cluster/tufts/hugheslab/eharve06/CIFAR-10.1"
ns = [100, 1_000, 10_000, 50_000]
random_states = [1001, 2001, 3001]
repo_dir = "/cluster/tufts/hugheslab/eharve06/random-Fourier-features"

for n, random_state in itertools.product(ns, random_states):
    encoded_path = f"{repo_dir}/datasets/CIFAR-10/n={n}_random_state={random_state}.pth"
    print(
        f"    \'python ../src/encode_cifar10.py "
        f"--batch_size=128 "
        f"--cifar10_dir=\"{cifar10_dir}\" "
        f"--cifar101_v4_dir=\"{cifar101_v4_dir}\" "
        f"--encoded_path=\"{encoded_path}\" "
        f"--n={n} "
        f"--num_workers=0 "
        f"--random_state={random_state}\'"
    )


    'python ../src/encode_cifar10.py --batch_size=128 --cifar10_dir="/cluster/tufts/hugheslab/eharve06/CIFAR-10" --cifar101_v4_dir="/cluster/tufts/hugheslab/eharve06/CIFAR-10.1" --encoded_path="/cluster/tufts/hugheslab/eharve06/random-Fourier-features/datasets/CIFAR-10/n=100_random_state=1001.pth" --n=100 --num_workers=0 --random_state=1001'
    'python ../src/encode_cifar10.py --batch_size=128 --cifar10_dir="/cluster/tufts/hugheslab/eharve06/CIFAR-10" --cifar101_v4_dir="/cluster/tufts/hugheslab/eharve06/CIFAR-10.1" --encoded_path="/cluster/tufts/hugheslab/eharve06/random-Fourier-features/datasets/CIFAR-10/n=100_random_state=2001.pth" --n=100 --num_workers=0 --random_state=2001'
    'python ../src/encode_cifar10.py --batch_size=128 --cifar10_dir="/cluster/tufts/hugheslab/eharve06/CIFAR-10" --cifar101_v4_dir="/cluster/tufts/hugheslab/eharve06/CIFAR-10.1" --encoded_path="/cluster/tufts/hugheslab/eharve06/random-Fourier-features/datasets/CIFAR-10/n=100_random_state=3001.pth" --n=100 --num

In [4]:
n, random_state = 1000, 1001
datasets = torch.load(f"{repo_dir}/datasets/CIFAR-10/n={n}_random_state={random_state}.pth", map_location=torch.device("cpu"), weights_only=False)

train_dataset = torch.utils.data.TensorDataset(datasets["X_train"], datasets["y_train"])
val_dataset = torch.utils.data.TensorDataset(datasets["X_val"], datasets["y_val"])
test_dataset = torch.utils.data.TensorDataset(datasets["X_test"], datasets["y_test"])
ood_dataset = torch.utils.data.TensorDataset(datasets["X_ood"], datasets["y_ood"])

batch_size = 128
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
ood_dataloader = torch.utils.data.DataLoader(ood_dataset, batch_size=batch_size)


In [5]:
def train_one_epoch(model, criterion, optimizer, dataloader):
    
    model.train()
    
    running_loss = 0.0
    for X_batch, y_batch in dataloader:
        
        batch_size = len(X_batch)
        
        optimizer.zero_grad()
        params = utils.flatten_params(model)
        logits = model(X_batch)
        loss = criterion(logits, y_batch, params, len(dataloader.dataset))
        running_loss += (1 / batch_size) * loss.item()
        loss.backward()
        optimizer.step()
        
    return running_loss
        
def evaluate(model, criterion, dataloader):
    
    model.eval()
    
    with torch.no_grad():
        
        running_loss = 0.0
        for X_batch, y_batch in dataloader:

            batch_size = len(X_batch)
            
            params = utils.flatten_params(model)
            logits = model(X_batch)
            loss = criterion(logits, y_batch, params, len(dataloader.dataset))
            running_loss += (1 / batch_size) * loss.item()
            
    return running_loss


In [15]:
model = layers.RFFLaplace(in_features=2048, out_features=10, rank=1024, lengthscale=20.0, outputscale=1.0)
likelihood = likelihoods.CategoricalLikelihood()
prior = priors.GaussianPrior()

map_criterion = losses.MAPLoss(likelihood, prior)
erm_criterion = losses.ERMLoss(likelihood)

state_dict = {
    "model": model.state_dict(),
    "likelihood": likelihood.state_dict(),
    "prior": prior.state_dict(),
}

best_val_loss = float("inf")
best_state_dict = None

for lr in [0.1, 0.01, 0.001, 0.0001]:

    model.load_state_dict(state_dict["model"])
    likelihood.load_state_dict(state_dict["likelihood"])
    prior.load_state_dict(state_dict["prior"])
    
    optimizer = torch.optim.Adam([{"params": model.parameters()}], lr=lr, weight_decay=0.0)
    
    epochs = 1_000
    
    for epoch in tqdm.tqdm(range(epochs)):
        
        train_loss = train_one_epoch(model, map_criterion, optimizer, train_dataloader)
        val_loss = evaluate(model, erm_criterion, val_dataloader)
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state_dict = {
            "model": model.state_dict(),
            "likelihood": likelihood.state_dict(),
            "prior": prior.state_dict(),
        }
        
model.load_state_dict(best_state_dict["model"])
likelihood.load_state_dict(best_state_dict["likelihood"])
prior.load_state_dict(best_state_dict["prior"])


100%|██████████| 1000/1000 [00:37<00:00, 26.50it/s]
100%|██████████| 1000/1000 [00:37<00:00, 26.33it/s]
100%|██████████| 1000/1000 [00:38<00:00, 26.10it/s]
100%|██████████| 1000/1000 [00:37<00:00, 26.51it/s]


<All keys matched successfully>

In [16]:
with torch.no_grad():
    
    model.update_covariance_from_dataloader(train_dataloader)

    test_probs = model.predict_proba(datasets["X_test"], num_samples=10_000)
    ood_probs = model.predict_proba(datasets["X_ood"], num_samples=10_000)
    

In [33]:
with torch.no_grad():

    num_classes = 10

    test_preds = torch.argmax(test_probs, dim=1)
    ood_preds = torch.argmax(ood_probs, dim=1)

    test_acc = torch.tensor([(test_preds[datasets["y_test"] == c] == c).float().mean() for c in range(num_classes)]).mean()
    ood_acc = torch.tensor([(ood_preds[datasets["y_ood"] == c] == c).float().mean() for c in range(num_classes)]).mean()
    print(test_acc)
    print(ood_acc)


tensor(0.8180)
tensor(0.6946)
