In [None]:
import torch
import argparse
from pathlib import Path
import time
import numpy as np
import os
from sklearn.linear_model import ElasticNet
import tqdm
import warnings
from sklearn.exceptions import ConvergenceWarning
from joblib import Parallel, delayed, cpu_count, dump, load
import random

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Dataset to run the notebook on. Must be one of:
#   "imagenet", "places365", "cub", "cifar10", "cifar100"
dataset = "cifar10"

In [None]:
def dictionary_project(dictionary, init_dictionary, r):
    """Constrain each dictionary row to stay within distance ``r`` of its initial row.

    Every row of ``dictionary`` is first scaled to unit L2 norm. A row whose
    Euclidean distance to the matching row of ``init_dictionary`` is larger than
    ``r`` is moved back along the difference direction so that this distance
    equals ``r``. The result is then re-normalised row-wise to unit L2 norm.

    Args:
        dictionary: 2-D tensor with one atom per row.
        init_dictionary: Tensor of the same shape holding the reference atoms.
        r: Maximum allowed distance (before the final re-normalisation) between
            a row and its reference row.

    Returns:
        A new tensor with the same shape as ``dictionary``. Neither input is
        modified in place.
    """
    dictionary = dictionary / torch.norm(dictionary, p=2, dim=1, keepdim=True)
    diff = dictionary - init_dictionary
    diff_length = torch.norm(diff, p=2, dim=1, keepdim=True)

    # Keep the mask at shape (n_rows, 1). Squeezing it (as an index-based
    # version would) collapses a single-row dictionary to a 0-d tensor and
    # silently skips the projection.
    exceeds = diff_length > r

    # Only rows that exceed the radius are divided by their own length; the
    # others get a harmless denominator of 1 so no 0/0 NaN is ever produced.
    safe_length = torch.where(exceeds, diff_length, torch.ones_like(diff_length))
    diff = torch.where(exceeds, diff / safe_length * r, diff)

    dictionary = init_dictionary + diff
    dictionary = dictionary / torch.norm(dictionary, p=2, dim=1, keepdim=True)
    return dictionary

In [None]:
def dictionary_dispersion(dictionary, center, ratio):
    """Rotate every dictionary column away from (or towards) ``center``.

    For each column the angle to ``center`` is multiplied by ``ratio`` while the
    column stays in the plane spanned by ``center`` and the original column.

    Args:
        dictionary: 2-D array whose columns are the atoms.
            # assumes columns are unit-norm so the dot product is a cosine — TODO confirm
        center: 1-D array, the direction angles are measured from.
            # assumes unit norm — TODO confirm against caller
        ratio: Factor applied to each atom's angle to ``center``.

    Returns:
        A new float array with the same shape as ``dictionary``.
    """
    new_dictionary = np.zeros(dictionary.shape)
    for i in range(dictionary.shape[1]):
        vec = dictionary[:, i]
        cos_angle = center.T @ vec
        # Rounding can push the cosine marginally outside [-1, 1], which would
        # make arccos return NaN.
        angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))

        # Component of the atom orthogonal to the centre direction.
        y_vec = vec - cos_angle * center
        y_norm = np.linalg.norm(y_vec)
        # An atom collinear with `center` has no orthogonal component; leaving
        # the zero vector as-is avoids a 0/0 division (NaN column). In that
        # degenerate case no rotation plane is defined, so only the `center`
        # term contributes below.
        if y_norm > 0:
            y_vec = y_vec / y_norm

        new_dictionary[:, i] = np.cos(ratio * angle) * center + np.sin(ratio * angle) * y_vec
    return new_dictionary

In [None]:
class LogisticRegression(torch.nn.Module):
    """Linear classifier with dropout applied to its input.

    Note: a ``BatchNorm1d`` layer (``bn_1``) is registered as a sub-module but
    ``forward`` does not call it; it is kept so the set of registered
    parameters/buffers stays the same.
    """

    def __init__(self, n_inputs, n_outputs):
        """Create the layers for ``n_inputs`` features and ``n_outputs`` classes."""
        super().__init__()
        self.bn_1 = torch.nn.BatchNorm1d(n_inputs)  # registered but unused in forward
        self.dropout = torch.nn.Dropout(p=0.5)
        self.linear = torch.nn.Linear(n_inputs, n_outputs)

    def forward(self, x):
        """Return the linear layer's output for the dropped-out input ``x``."""
        dropped = self.dropout(x)
        return self.linear(dropped)

In [None]:
def get_model(dataset_name):
    """Build the ``LogisticRegression`` classifier sized for ``dataset_name``.

    Args:
        dataset_name: One of "cifar10", "cifar100", "cub", "places365",
            "imagenet".

    Returns:
        A freshly constructed ``LogisticRegression`` whose input/output sizes
        match the chosen dataset.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names
            (previously this case silently returned ``None``).
    """
    # dataset name -> (n_inputs, n_outputs)
    # NOTE(review): n_inputs presumably equals the number of dictionary atoms
    # for that dataset — confirm against the saved dictionaries.
    model_dims = {
        "cifar10": (128, 10),
        "cifar100": (824, 100),
        "cub": (208, 200),
        "places365": (2207, 365),
        "imagenet": (4523, 1000),
    }
    if dataset_name not in model_dims:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of {sorted(model_dims)}"
        )
    n_inputs, n_outputs = model_dims[dataset_name]
    return LogisticRegression(n_inputs=n_inputs, n_outputs=n_outputs)


In [None]:
# Load the saved dictionary and convert it to a NumPy array.
# map_location="cpu" lets a file that was saved from a CUDA device be opened
# on a machine without a GPU; detach() makes the NumPy conversion work even
# if the stored tensor requires grad.
# NOTE(review): torch.load unpickles the file — only load files you trust.
dictionary = torch.load(f'ip_omp/saved_files/{dataset}_dictionary.pt', map_location="cpu")
dictionary = dictionary.detach().cpu().numpy()

In [None]:
def _load_mmap(path):
    """Open a saved ``.npy`` file as a read-only memory map."""
    return np.load(path, mmap_mode="r")


# Training split: embeddings are materialised as a tensor, labels stay a memmap.
datay = torch.tensor(_load_mmap(f"ip_omp/saved_files/{dataset}_train_embeddings.npy"))
dataz = _load_mmap(f"ip_omp/saved_files/{dataset}_train_labels.npy")

# Test split, loaded the same way.
datay_test = torch.tensor(_load_mmap(f"ip_omp/saved_files/{dataset}_test_embeddings.npy"))
dataz_test = _load_mmap(f"ip_omp/saved_files/{dataset}_test_labels.npy")

In [None]:
# Pair each embedding with its label so a DataLoader can yield (x, y) batches.
train_labels_t = torch.tensor(dataz)
test_labels_t = torch.tensor(dataz_test)

train_ds = torch.utils.data.TensorDataset(datay, train_labels_t)
test_ds = torch.utils.data.TensorDataset(datay_test, test_labels_t)

In [None]:
# Report the CUDA devices visible to PyTorch (or say that there are none).
if not torch.cuda.is_available():
    print("No GPU available.")
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")
    # One line per device with its index and model name.
    for i in range(num_gpus):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

## Constrained Concept Refinement 

In [None]:
# Atom dispersion: spread the atoms (columns of `dictionary`) around their
# normalised mean direction so that the atom farthest from the mean ends up
# at 90 degrees (arccos(0)) from it.
# assumes the columns of `dictionary` are unit-norm, so the dot products below
# are cosines — TODO confirm
mean = np.mean(dictionary, axis=1)
mean = mean / np.linalg.norm(mean)

# Clip so rounding error cannot push a cosine outside arccos's domain.
cos_to_mean = np.clip(mean.T @ dictionary, -1.0, 1.0)

# Largest angle between any atom and the mean direction.
max_angle = np.arccos(cos_to_mean.min())

# If every atom coincides with the mean there is nothing to disperse; use a
# neutral ratio instead of dividing by zero.
ratio = np.arccos(0) / max_angle if max_angle > 0 else 1.0
new_dictionary = dictionary_dispersion(dictionary, mean, ratio)

In [None]:
bs = 100  # mini-batch size

# Suggested hard-threshold value for each dataset. Looking the value up by
# `dataset` keeps the threshold in sync when the dataset is switched.
suggested_thresholds = {
    "cifar10": 0.15,
    "cifar100": 0.215,
    "cub": 0.22,
    "imagenet": 0.24,
    "places365": 0.22,
}
threshold = suggested_thresholds[dataset]

radius_bound = 0.1  # radius passed to dictionary_project after each step
d_lr = 1e-4         # learning rate of the dictionary parameter group
l_lr = 1e0          # learning rate of the classifier parameters
mmtm = 0.9          # SGD momentum
criterion = torch.nn.CrossEntropyLoss()

model_d = get_model(dataset)
model_d.to(device)

# Trainable copy of the dispersed dictionary, plus a frozen snapshot of its
# starting value that serves as the reference for the projection step.
dictionary_d = torch.tensor(new_dictionary, dtype=torch.float32).to(device)
dictionary_d.requires_grad = True
ori_dictionary_d = dictionary_d.detach().clone()

# One optimiser, two parameter groups: the classifier uses the default
# learning rate (l_lr), the dictionary overrides it with d_lr.
optimizer_d = torch.optim.SGD([
            {'params': model_d.parameters()},
            {'params': dictionary_d, 'lr': d_lr}
        ], lr=l_lr, momentum=mmtm)

# Zeroes every coefficient that is not strictly greater than `threshold`.
hard_thresh = torch.nn.Threshold(threshold,0)

In [None]:
ctr = 0     # number of completed epochs
niter = 20  # total number of epochs to run

# Per-epoch logs: test accuracy (%), mean number of active coefficients per
# test sample, and mean / max column-wise deviation of the dictionary from
# its initial value.
acc_log_d = np.zeros(niter)
spr_log_d = np.zeros(niter)
avg_log_d = np.zeros(niter)
max_log_d = np.zeros(niter)

hard_thresh = torch.nn.Threshold(threshold,0)

# `while ctr < niter` (rather than `while True` + break) also behaves for
# niter == 0 instead of running an epoch and then indexing past the logs.
while ctr < niter:
    # ---- Train for one epoch ----
    dataloader = torch.utils.data.DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=1)
    model_d.train()
    loss_sum = torch.zeros((), device=device)
    n_seen = 0
    for x, y in tqdm.tqdm(dataloader):
        x = x.to(device)
        y = y.to(device).long()

        optimizer_d.zero_grad()

        # Thresholded correlations between inputs and dictionary columns.
        coeffs = hard_thresh(x @ dictionary_d)
        outputs = model_d(coeffs)

        loss = criterion(outputs, y)
        loss.backward()
        optimizer_d.step()

        # Project the updated dictionary back into the allowed region around
        # its initial value. Done in place under no_grad so the parameter
        # object (and its optimiser state) is kept and no graph is recorded.
        with torch.no_grad():
            dictionary_d.copy_(
                dictionary_project(dictionary_d.T, ori_dictionary_d.T, radius_bound).T
            )

        # Accumulate a sample-weighted sum so the reported loss is the epoch
        # mean rather than whatever the last mini-batch happened to give.
        loss_sum += loss.detach() * y.size(0)
        n_seen += y.size(0)

    loss = loss_sum.item() / max(n_seen, 1)

    # ---- Evaluate on the test split ----
    dataloader = torch.utils.data.DataLoader(test_ds, batch_size=bs, shuffle=False, num_workers=1)
    model_d.eval()
    correct = 0
    sparsity = 0
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for x, y in tqdm.tqdm(dataloader):
            x = x.to(device)
            y = y.to(device).long()

            coeffs = hard_thresh(x @ dictionary_d)
            # Count coefficients that survived the threshold.
            sparsity += (torch.abs(coeffs) > 1e-4).sum().item()
            outputs = model_d(coeffs)

            predicted = torch.argmax(outputs, 1)
            correct += (predicted == y).sum().item()
    acc = 100 * correct / len(test_ds)
    spr = sparsity / len(test_ds)

    # How far each dictionary column has moved from its starting point.
    with torch.no_grad():
        column_norms = torch.norm(dictionary_d - ori_dictionary_d, p=2, dim=0)
        average_norm = torch.mean(column_norms)
        largest_norm = torch.max(column_norms)
    print("Epoch:", ctr, "Train Loss:", loss, "Test accuracy:", acc)
    print("Sparsity:", spr)
    print("Average column deviation:",average_norm.cpu().numpy(),"Maximum column deviation:",largest_norm.cpu().numpy())

    acc_log_d[ctr] = acc
    spr_log_d[ctr] = spr
    # .item() gives a plain Python float, which is safe to store in a NumPy
    # array even when the tensor lives on a CUDA device.
    avg_log_d[ctr] = average_norm.item()
    max_log_d[ctr] = largest_norm.item()
    ctr += 1