In [1]:
import os
import pickle

import resnet
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
import torch.utils as tutils
from torch.nn import functional as F

import torchvision.transforms as tv_transforms
import torchvision.utils as vutils
import torchvision.datasets as tv_datasets
import matplotlib.pyplot as plt

from sklearn.metrics.pairwise import cosine_similarity
# from gen_model import FeaturesGenerator, InferenceQYZ
from sklearn import preprocessing

# Seed every RNG used below (latent initialisation, shuffling, sampling) so a fresh
# Restart & Run All reproduces the same results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the first GPU when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

Using cuda:0


In [None]:
class FeaturesGenerator(torch.nn.Module):
    """Conditional MLP generator: (latent code, conditioning vector) -> feature vector.

    Args:
        latent_dim: Width of the latent code ``x`` passed to :meth:`forward`.
        num_k: Width of the conditioning vector ``y`` (a one-hot class encoding in this notebook).
        out_dim: Width of the generated feature vector.
    """

    def __init__(self, latent_dim, num_k, out_dim):
        # Initialise nn.Module before assigning any attributes on the instance.
        super(FeaturesGenerator, self).__init__()
        self.num_nodes = latent_dim + num_k
        self.main = nn.Sequential(
            nn.Linear(self.num_nodes, 512),
            nn.LeakyReLU(0.2, True),
            nn.Linear(512, out_dim),
            # Sigmoid bounds every output to (0, 1); target features outside that range
            # cannot be reproduced exactly by this generator.
            nn.Sigmoid(),
        )

    def forward(self, x, y):
        """Generate features for latent codes ``x`` conditioned on ``y``.

        ``x`` and ``y`` are concatenated along dim 1, so their widths must sum to
        ``latent_dim + num_k``; the result has ``out_dim`` columns.
        """
        in_vec = torch.cat([x, y], dim=1)
        return self.main(in_vec)


class LinearCLS(nn.Module):
    """Linear classifier returning log-probabilities (pair with ``nn.NLLLoss``).

    Args:
        input_dim: Width of the input feature vectors.
        nclass: Number of output classes.
    """

    def __init__(self, input_dim, nclass):
        super().__init__()
        self.fc = nn.Linear(input_dim, nclass)
        self.logic = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``."""
        logits = self.fc(x)
        return self.logic(logits)

# One-Time Feature Extraction

Features are extracted with the pretrained ResNet-56 from https://github.com/akamaster/pytorch_resnet_cifar10.

In [13]:
# Pretrained ResNet-56 wrapped in DataParallel before loading, presumably because the checkpoint's
# state_dict keys carry the "module." prefix — TODO confirm against the checkpoint.
model = nn.DataParallel(resnet.resnet56())
# map_location lets a checkpoint that was saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load("pretrained_weights/resnet56-4bfd9763.th", map_location=device)
model.load_state_dict(checkpoint['state_dict'])
# Inference only: eval() makes BatchNorm use its running statistics.
model = model.to(device).eval()
# Optional: presumably turns the network's output into the penultimate features and keeps the
# original classification head as `fc_cls` (used by the "Evaluate Model 1" cell).
# model.module.linear = nn.Identity()
# fc_cls = model.module.linear

In [14]:
def read_pickle(data_path, file_list, verify_checksums=False):
    """Load CIFAR-style pickled batch files and stack them into one image array.

    Args:
        data_path: Directory containing the batch files.
        file_list: Iterable of ``(file_name, md5_checksum)`` pairs.
        verify_checksums: If True, compare each file's MD5 digest with the listed checksum and
            raise ``ValueError`` on mismatch. Defaults to False (checksums are ignored).

    Returns:
        ``(data, targets)``: ``data`` has shape ``(N, 32, 32, 3)`` (HWC), built from each
        entry's ``'data'`` rows reshaped to ``(3, 32, 32)``; ``targets`` is a list of the
        ``'labels'`` values (CIFAR-10) or, when absent, ``'fine_labels'`` (CIFAR-100).

    Raises:
        ValueError: if ``file_list`` is empty or a checksum does not match.
    """
    import hashlib

    data = []
    targets = []
    for file_name, checksum in file_list:
        file_path = os.path.join(data_path, file_name)
        with open(file_path, 'rb') as f:
            raw = f.read()
        if verify_checksums:
            digest = hashlib.md5(raw).hexdigest()
            if digest != checksum:
                raise ValueError(f"MD5 mismatch for {file_path}: expected {checksum}, got {digest}")
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        entry = pickle.loads(raw, encoding='latin1')
        data.append(entry['data'])
        targets.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])
    if not data:
        raise ValueError("file_list is empty; nothing to load")
    # Each row is a flattened (3, 32, 32) CHW image; convert to HWC.
    data = np.vstack(data).reshape(-1, 3, 32, 32)
    data = data.transpose((0, 2, 3, 1))
    return data, targets


# CIFAR-10 (python version) batch files with their published MD5 checksums.
train_list = [
    ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
    ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
    ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
    ['data_batch_4', '634d18415352ddfa80567beed471001a'],
    ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]
test_list = [
    ['test_batch', '40351d587109b95175f43aff81a1287e'],
]

cifar_dir = "data/cifar-10-batches-py/"
x_train, y_train = read_pickle(cifar_dir, train_list)
x_valid, y_valid = read_pickle(cifar_dir, test_list)

# (N, H, W, C) arrays -> float32 tensors in (N, C, H, W) layout; the cast leaves pixel values unchanged.
x_train, x_valid = [
    torch.from_numpy(images.astype("float32")).permute(0, 3, 1, 2)
    for images in (x_train, x_valid)
]
# Label lists -> numpy arrays.
y_train, y_valid = np.array(y_train), np.array(y_valid)

In [15]:
# Preprocessing for the pretrained network: ToTensor scales pixels to [0, 1], then the listed
# per-channel mean/std normalisation is applied. Resize(32) leaves 32x32 CIFAR images unchanged.
img_trans = tv_transforms.Compose([
    tv_transforms.ToPILImage(),
    tv_transforms.Resize(32),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
])


def _transform_images(images):
    """Apply `img_trans` to every image of an (N, C, H, W) tensor and stack the results.

    ToPILImage assumes floating-point tensors hold values in [0, 1] and multiplies them by 255
    before converting to uint8. Float tensors holding raw 0-255 pixel values would therefore
    overflow. Each image is cast to uint8 first. This is lossless for integer-valued pixels,
    such as uint8 data that was only cast to float32, and a no-op for uint8 input.
    """
    return torch.stack([img_trans(image.to(torch.uint8)) for image in images], dim=0)


# NOTE: not idempotent — re-running this cell would transform already-normalised tensors.
x_train = _transform_images(x_train)
x_valid = _transform_images(x_valid)

In [21]:
# Sanity check: top-1 accuracy of the pretrained ResNet-56 on the preprocessed CIFAR-10 test images.
mb_size = 128
# BatchNorm must use its running statistics, not per-batch statistics, during evaluation.
model.eval()
corr = 0
with torch.no_grad():
    for start in range(0, x_valid.size(0), mb_size):
        x_mb = x_valid[start:start + mb_size].float().to(device)
        # as_tensor accepts y_valid whether it is still a numpy array or already a tensor.
        y_mb = torch.as_tensor(y_valid[start:start + mb_size])
        # Run the full network on the images. model.module.linear is (presumably) only the final
        # classification layer and cannot consume (B, 3, 32, 32) image batches.
        output = model(x_mb)
        pred = torch.argmax(output, dim=1)
        # .item() keeps `corr` a plain Python int so the accuracy prints as a number.
        corr += torch.eq(y_mb, pred.cpu()).sum().item()

acc = 100 * corr / float(x_valid.size(0))
print(f"accuracy: {acc}")

accuracy: 93.38999938964844


In [13]:
# Extract backbone features for every train and test image via the project-specific
# `model.module.get_feas` (defined in the `resnet` module, not visible here).
mb_size = 128
model.eval()


def _extract_feas(images, batch_size=mb_size):
    """Run model.module.get_feas over `images` in mini-batches and concatenate along dim 0."""
    chunks = []
    with torch.no_grad():
        for start in range(0, images.size(0), batch_size):
            x_mb = images[start:start + batch_size].to(device)
            chunks.append(model.module.get_feas(x_mb))
    return torch.cat(chunks, dim=0)


train_feas = _extract_feas(x_train)
test_feas = _extract_feas(x_valid)

In [14]:
# Cache the extracted features (moved to CPU) together with the integer labels.
feature_dump = {
    "x_train": train_feas.cpu().detach(),
    "x_valid": test_feas.cpu().detach(),
    "y_train": torch.from_numpy(y_train).long(),
    "y_valid": torch.from_numpy(y_valid).long(),
}
torch.save(feature_dump, "data/cifar-10-batches-py/resnet56_midfeas_unnorm.tar")

# Load Extracted Features

In [19]:
# NOTE(review): this loads "resnet56_feas.tar", while the extraction cell above saves
# "resnet56_midfeas_unnorm.tar" — confirm which feature file is intended.
checkpoint = torch.load("data/cifar-10-batches-py/resnet56_feas.tar")
y_train = checkpoint['y_train']
y_valid = checkpoint["y_valid"]

# Features round-trip through numpy. To rescale them (e.g. preprocessing.MinMaxScaler fit on the
# training features and applied to the validation features), transform the numpy arrays here.
# NOTE(review): FeaturesGenerator ends in a Sigmoid (outputs in (0, 1)); if these features are
# not already in that range, the generator cannot reproduce them exactly.
train_np = checkpoint["x_train"].numpy()
valid_np = checkpoint["x_valid"].numpy()
x_train = torch.from_numpy(train_np)
x_valid = torch.from_numpy(valid_np)

# Train Upper-Bound Classifier (linear classifier on all real training features)

In [None]:
# Upper bound: a linear classifier trained on all real training features (Adam, 20 epochs).
mb_size = 128
linear_cls = LinearCLS(64, 10).to(device)
optimizer_cls = torch.optim.Adam(linear_cls.parameters(), lr=0.001, betas=(0.5, 0.999))
cls_criterion = nn.NLLLoss()

num_iter = x_train.size(0) // mb_size
bd_indices = list(range(0, x_train.size(0), mb_size))

for epc in range(20):
    # Reshuffle every epoch; contiguous slices of the permutation form the mini-batches.
    batch_indices = np.random.permutation(np.arange(x_train.size(0)))
    for idx, iter_count in enumerate(bd_indices):
        batch_idx = batch_indices[iter_count:iter_count + mb_size]
        log_probs = linear_cls(x_train[batch_idx].float().to(device))
        loss = cls_criterion(log_probs, y_train[batch_idx].to(device))
        optimizer_cls.zero_grad()
        loss.backward()
        optimizer_cls.step()


In [None]:
# Top-1 accuracy of the upper-bound linear classifier on the real validation features.
corr = 0
with torch.no_grad():
    for start in range(0, x_valid.size(0), mb_size):
        x_mb = x_valid[start:start + mb_size].float().to(device)
        y_mb = y_valid[start:start + mb_size].to(device)
        pred = torch.argmax(linear_cls(x_mb), dim=1)
        # .item() keeps `corr` a plain Python int, so the accuracy below prints as a number
        # rather than as a (possibly CUDA) tensor repr.
        corr += torch.eq(y_mb, pred).sum().item()

acc = 100 * corr / float(x_valid.size(0))
print(f"accuracy: {acc}")

# Task 1: train on classes 4–5, evaluate on classes 0–5

In [None]:
# Task 1: train on the real features of classes 4 and 5 only; validate on classes 0-5.
task_mask_1 = y_train >= 4
task_mask_2 = y_train < 6
task_mask = task_mask_1 & task_mask_2
train_feas, train_label = x_train[task_mask], y_train[task_mask]

task_mask_1 = y_valid >= 0
task_mask_2 = y_valid < 6
task_mask = task_mask_1 & task_mask_2
valid_feas, valid_label = x_valid[task_mask], y_valid[task_mask]

print(f"train X: {train_feas.shape}")
print(f"train Y : {train_label.shape}")
print(f"valid X: {valid_feas.shape}")
print(f"valid Y: {valid_label.shape}")

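# Replay: prepend generated features (`gen_feat`, `gen_label`) to the Task 1 training data

Requires the sample-generation cell in the "Evaluate Model 2" section to have been run first.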

In [None]:
# Prepend the generator's samples (gen_feat / gen_label, produced in the "Evaluate Model 2"
# section) to the real Task 1 data. Re-running this cell would prepend them a second time.
train_feas = torch.cat((gen_feat, train_feas), dim=0)
train_label = torch.cat((gen_label, train_label), dim=0)
print(f"train X: {train_feas.shape}")
print(f"train Y : {train_label.shape}")
print(f"valid X: {valid_feas.shape}")
print(f"valid Y: {valid_label.shape}")

In [None]:
def weights_init(m):
    """Initialise Conv* and BatchNorm* modules in place; intended for ``module.apply``.

    Modules are matched by class name:
      * names containing 'Conv': weight ~ Kaiming-normal (fan_in, ReLU gain);
      * names containing 'BatchNorm': weight ~ N(1, 0.02), bias = 0.
    Any weight/bias that is absent or None (e.g. BatchNorm with affine=False, or a container
    whose name merely contains 'Conv') is skipped instead of raising.
    All other modules, including nn.Linear, are left unchanged.

    NOTE(review): FeaturesGenerator is built only from Linear/LeakyReLU/Sigmoid layers, so
    ``netG.apply(weights_init)`` leaves it at PyTorch's default initialisation.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.kaiming_normal_(weight, mode='fan_in', nonlinearity='relu')
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)


def get_recon_loss(pred, x, sigma):
    """Gaussian reconstruction energy: sum((x - pred)^2) / (2 * sigma^2), summed over all elements.

    Equals the negative log-likelihood of ``x`` under N(pred, sigma^2) up to an additive constant.
    """
    squared_error = torch.pow(x - pred, 2).sum()
    return 1 / (2 * sigma ** 2) * squared_error


def get_prior_loss(z, mus, logsigma):
    """Negative log-density of ``z`` under a diagonal Gaussian, summed over all elements.

    ``logsigma`` is used as the log-variance: the per-element term is
    0.5 * (log(2*pi) + logsigma + (z - mus)^2 / exp(logsigma)).
    Despite the original local name, the returned value is a loss (negative log-pdf).
    """
    log_two_pi = np.log(2.0 * np.pi)
    scaled_sq_dist = torch.pow(z - mus, 2) / torch.exp(logsigma)
    return 0.5 * torch.sum(log_two_pi + logsigma + scaled_sq_dist)


def get_prior_loss_mm(z, mus, logsigma):
    """Log-density of each code in ``z`` under every diagonal-Gaussian mixture component.

    ``z`` is broadcast against all components via unsqueeze(1) / unsqueeze(0); ``logsigma`` is
    the per-component log-variance. Returns a (B, nc) tensor of log-pdfs. Note that, unlike
    get_prior_loss, this is a log-density (higher is better), not a loss.
    """
    z_b = z.unsqueeze(1)              # (B, 1, D)
    mu_b = mus.unsqueeze(0)           # (1, nc, D)
    logvar_b = logsigma.unsqueeze(0)  # (1, nc, D)
    dist_term = torch.pow(z_b - mu_b, 2) / torch.exp(logvar_b)
    return -0.5 * torch.sum(np.log(2.0 * np.pi) + logvar_b + dist_term, dim=-1)


def get_entropy_loss(logits, probs):
    """Cross-entropy between target distribution ``probs`` and softmax(``logits``), summed over rows.

    The softmax is taken over dim 1; the per-row sum over the last dim is then summed over the batch.
    """
    log_q = torch.log_softmax(logits, dim=1)
    per_row = torch.sum(probs * log_q, dim=-1)
    return torch.sum(-per_row)

In [None]:
# Hyper-parameters of the latent-variable feature generator.
latent_dim = 10        # width of each latent code z
num_k = 10             # number of classes / mixture components (width of the one-hot label)
mb_size = 64
lr_rate = 0.0002
weight_decay = 0.001
langevin_s = 0.3       # Langevin step size; also the std of the noise injected into z
langevin_steps = 30    # latent-update steps per round in the training loop
sigma = 0.3            # std of the Gaussian reconstruction term (get_recon_loss)

# One latent code per training example, drawn from N(0, 1).
train_z = torch.empty(train_feas.shape[0], latent_dim, dtype=torch.float32).normal_(0, 1).to(device)
# Learnable mixture parameters: component means ~ N(0, 5^2); log-variances start at 0.
mus = torch.empty(num_k, latent_dim, dtype=torch.float32).normal_(0, 5).to(device).requires_grad_()
logsigma = torch.zeros(num_k, latent_dim, dtype=torch.float32, device=device, requires_grad=True)

netG = FeaturesGenerator(latent_dim, num_k, 64).to(device)
netG.apply(weights_init)

num_iter = train_feas.size(0) // mb_size
bd_indices = list(range(0, train_feas.size(0), mb_size))

In [None]:
# Re-draw one N(0, 1) latent code per row of the training set. The leading rows are then
# overwritten with `replay_z`, the latents that produced the replayed samples. This assumes
# the Replay cell placed the generated samples first.
train_z = torch.empty(train_feas.shape[0], latent_dim, dtype=torch.float32).normal_(0, 1).to(device)
train_z[:len(replay_z)] = replay_z

### Load checkpoint (Optional)

In [None]:
# Restore the generator weights, per-example latent codes and mixture parameters written by
# the "Save checkpoint" cell. map_location allows loading on a machine without the original GPU.
checkpoint = torch.load('cifar10feas_task1.tar', map_location=device)

netG.load_state_dict(checkpoint["netG_state_dict"])
train_z = checkpoint['train_z_state'].to(device)
# mus/logsigma were saved detached. Re-enable gradients so the training cell keeps optimising
# them, as it does when they are freshly initialised; otherwise they would silently stay frozen.
mus = checkpoint["mus"].to(device).requires_grad_()
logsigma = checkpoint["logsigma"].to(device).requires_grad_()

### Train Model

In [None]:
# Task 1 training. For each mini-batch, 2 rounds ("em") alternate between:
#   (1) one Adam step on netG, mus and logsigma using the batch energy `gloss`, and
#   (2) `langevin_steps` updates of the batch latent codes z_mb. Each update is an Adam step on
#       the scaled energy followed by adding Gaussian noise of scale `langevin_s`.
# The energy is (recon_loss + prior_loss + entropy_loss) / batch size, where
#   * recon_loss   - Gaussian reconstruction energy of x_mb from netG(z_mb, one-hot label), std `sigma`;
#   * prior_loss   - negative log-density of z_mb under its label's component (mus[y], logsigma[y]);
#   * entropy_loss - cross-entropy between the one-hot label and softmax over the per-component
#                    log-densities, i.e. it pushes z_mb towards its label's component.
# optimizer_g = torch.optim.Adam(list(netG.parameters()) + list(netD.parameters()), lr=lr_rate, weight_decay=weight_decay, betas=(0.5,0.999))
optimizer_g = torch.optim.Adam(netG.parameters(), lr=lr_rate, weight_decay=weight_decay, betas=(0.5,0.999))
# The mus/logsigma group inherits lr, weight_decay and betas from the optimizer's defaults.
optimizer_g.add_param_group({"params": [mus, logsigma]})

for epc in range(20):
    batch_indices = np.random.permutation(np.arange(train_feas.size(0)))
    epc_loss = 0
    for idx, iter_count in enumerate(bd_indices):
        batch_idx = batch_indices[iter_count:iter_count+mb_size]
        x_mb = train_feas[batch_idx].float().to(device)
        y_mb = train_label[batch_idx]
        # Integer-array indexing returns a copy, so the updated codes are written back to
        # train_z explicitly at the end of each round.
        z_mb = train_z[batch_idx]
        z_mb.requires_grad_()
        # A fresh optimizer per batch: Adam's moment estimates for z are not carried across batches.
        optimizer_z = torch.optim.Adam([z_mb], lr=lr_rate, weight_decay=weight_decay, betas=(0.5,0.999))
        batch_loss = 0
        for em in range(2):
            # --- (1) parameter step on netG, mus, logsigma ---
            optimizer_g.zero_grad()
            one_hot_y = torch.eye(num_k)[y_mb]
            recon_x = netG(z_mb, one_hot_y.to(device))
            recon_loss = get_recon_loss(recon_x, x_mb, sigma)  # Reconstruction Loss

            log_pdfs = get_prior_loss_mm(z_mb, mus, logsigma)
#             y_cat = torch.argmax(log_pdfs, dim=1).detach()
            # NOTE(review): yita_c (normalised component responsibilities) is computed but never used.
            yita_c = torch.exp(log_pdfs) + 1e-10
            yita_c = yita_c/torch.sum(yita_c, dim=1).view(-1,1)
            entropy_loss = get_entropy_loss(log_pdfs, one_hot_y.to(device))  # Entropy Loss
            
            prior_loss = get_prior_loss(z_mb, mus[y_mb], logsigma[y_mb])
#             prior_loss = 0.5*torch.sum(torch.pow(z_mb, 2))
            
            gloss = recon_loss + prior_loss + entropy_loss
            gloss /= x_mb.size(0)
            # Also populates z_mb.grad (z_mb requires grad); optimizer_z.zero_grad() below clears it.
            gloss.backward()
            optimizer_g.step()
            # --- (2) latent updates for z_mb ---
            srmc_loss = 0
            for _ in range(langevin_steps):
                optimizer_z.zero_grad()
                u_tau = torch.randn(z_mb.size(0), latent_dim).float().to(device)

                one_hot_y = torch.eye(num_k)[y_mb]
                recon_x = netG(z_mb, one_hot_y.to(device))
                recon_loss = get_recon_loss(recon_x, x_mb, sigma)
                
                log_pdfs = get_prior_loss_mm(z_mb, mus, logsigma)
                # NOTE(review): y_cat and yita_c are computed but never used in this loop.
                y_cat = torch.argmax(log_pdfs, dim=1).detach()
                yita_c = torch.exp(log_pdfs) + 1e-10
                yita_c = yita_c/torch.sum(yita_c, dim=1).view(-1,1)
                entropy_loss = get_entropy_loss(log_pdfs, one_hot_y.to(device)) 
            
                prior_loss = get_prior_loss(z_mb, mus[y_mb], logsigma[y_mb])
#                 prior_loss = 0.5*torch.sum(torch.pow(z_mb, 2))
                
                loss = recon_loss + prior_loss + entropy_loss
                loss /= x_mb.size(0)
                # NOTE(review): the z step is taken by Adam, which is approximately invariant to a
                # constant gradient scale, so this langevin_s**2/2 factor has little effect on the
                # drift; the effective step size is governed by lr_rate.
                loss = langevin_s**2/2*loss
                # Also accumulates gradients into netG, mus and logsigma; these are discarded by
                # optimizer_g.zero_grad() at the start of the next round.
                loss.backward()
                optimizer_z.step()
                # Langevin noise: add langevin_s * N(0, I) to the codes outside autograd.
                z_mb.data += u_tau * langevin_s
                srmc_loss += loss.detach()
                
            # Write the updated codes back into the persistent per-example latent table.
            train_z[batch_idx,] = z_mb.data
            batch_loss += (srmc_loss / langevin_steps) + gloss.detach()
        batch_loss /= 2
        epc_loss += batch_loss
    # recon/prior/entropy shown here are the unscaled terms from the last latent-update step
    # of the last batch, not epoch averages; `loss` is the mean batch loss over the epoch.
    print(f"Epoch {epc+1} End; loss: {(epc_loss/(idx+1)): .4f}; recon: {recon_loss: .4f}; prior: {prior_loss: .4f}; entropy: {entropy_loss: .4f}")
#     print(f"Epoch {epc+1} End; loss: {(epc_loss/(idx+1)): .4f}")

### Save checkpoint

In [None]:
# Persist everything needed to resume Task 1: generator weights, per-example latent codes
# and the mixture parameters (all latent/mixture tensors moved to CPU).
task1_state = {
    'netG_state_dict': netG.state_dict(),
    'train_z_state': train_z.cpu().detach(),
    'mus': mus.cpu().detach(),
    'logsigma': logsigma.cpu().detach(),
}
torch.save(task1_state, "cifar10feas_task1.tar")

### Evaluate Model 1 (Not run)

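Short-run MCMC inference of the latent codes for the validation features, followed by classification of the reconstructed features. This cell has not been run and does not execute as written.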

In [None]:
# Short-run inference of a latent code for each validation feature: minimise the scaled
# reconstruction energy w.r.t. z, then classify the reconstruction (and the original feature)
# with `fc_cls` and compare accuracies.
# NOTE(review): this cell does not run as written:
#   * `fc_cls` is only defined in a commented-out line of the model-loading cell -> NameError.
#   * `netG(z_samp, 1)` passes the int 1 as the conditioning input, but FeaturesGenerator.forward
#     concatenates its inputs with torch.cat, which needs a tensor. This is presumably left over
#     from an unconditional generator; a (B, num_k) conditioning tensor is needed here.
#   * `z_samp` is created without requires_grad, so although it is registered with `optim_new`
#     it never receives a gradient and is never updated; only netG's weights would change.
#   * `optim_new` optimises netG.parameters(), so this would overwrite the trained generator
#     in place (cf. the "Maybe deep copy the weights" comment below).
#   * No noise is injected, so this is plain gradient descent on z rather than Langevin sampling.
mb_size = 64
test_size = valid_feas.size(0)
corr = 0
corr_1 = 0

# Maybe deep copy the weights
for batch_idx in range(0, test_size, mb_size):
    x_new = valid_feas[batch_idx:batch_idx+mb_size].to(device)
    y_new = valid_label[batch_idx:batch_idx+mb_size].to(device)

    z_samp = torch.randn(x_new.size(0),latent_dim).to(device)
    optim_new = torch.optim.Adam(netG.parameters(), lr=lr_rate, weight_decay=weight_decay, betas=(0.5,0.999))
    optim_new.add_param_group({"params": [z_samp]})
#     optim_new = torch.optim.Adam([z_samp], lr=lr_rate, weight_decay=weight_decay, betas=(0.5,0.999))

    # 50 gradient steps on the per-sample reconstruction energy.
    for _ in range(50):
        recon_x = netG(z_samp, 1)
        recon_loss = get_recon_loss(recon_x, x_new, sigma)
        loss = (recon_loss)/y_new.shape[0]
        loss = langevin_s**2/2*loss
        optim_new.zero_grad()
        loss.backward()
        optim_new.step()
    recon_feas = netG(z_samp,1)
    # Count correct predictions on reconstructed vs. original features.
    pred_cls = torch.argmax(fc_cls(recon_feas), dim=1).detach()
    orig_pred = torch.argmax(fc_cls(x_new), dim=1).detach()
    num_correct = torch.sum(torch.eq(pred_cls, y_new))
    corr += num_correct
    num_correct = torch.sum(torch.eq(orig_pred, y_new))
    corr_1 += num_correct
    
acc = corr / float(test_size)
print(f"Inferential Backprop: {acc*100: .4f}")
acc = corr_1 / float(test_size)
print(f"Original weights: {acc*100: .4f}")

### Evaluate Model 2

Linear classifier: train on generated features only, then evaluate on the real validation features.

In [None]:
# Generate nSample features for each of classes 0-5 with the trained generator, and remember
# the latent codes used (`replay_z`) so the same samples can be replayed later.
nSample = 300
feat_chunks = []
replay_z = []
with torch.no_grad():
    for cls_id in range(6):
        cond = torch.eye(num_k)[cls_id].repeat(nSample, 1).to(device)
        z = torch.randn(nSample, latent_dim).to(device)
        feat_chunks.append(netG(z, cond).cpu())
        replay_z.append(z)

gen_feat = torch.cat(feat_chunks, dim=0)
# Labels 0,...,0,1,...,1,...,5: nSample copies of each class id, as int64.
gen_label = torch.arange(6).repeat_interleave(nSample)
replay_z = torch.cat(replay_z, dim=0)

In [None]:
# Train a fresh linear classifier purely on generated features (gen_feat / gen_label).
mb_size = 128
linear_cls = LinearCLS(64, 10).to(device)
optimizer_cls = torch.optim.Adam(linear_cls.parameters(), lr=0.001, betas=(0.5, 0.999))
cls_criterion = nn.NLLLoss()

num_iter = gen_feat.size(0) // mb_size
bd_indices = list(range(0, gen_feat.size(0), mb_size))

for epc in range(20):
    # Reshuffle every epoch; contiguous slices of the permutation form the mini-batches.
    batch_indices = np.random.permutation(np.arange(gen_feat.size(0)))
    for idx, iter_count in enumerate(bd_indices):
        batch_idx = batch_indices[iter_count:iter_count + mb_size]
        log_probs = linear_cls(gen_feat[batch_idx].float().to(device))
        loss = cls_criterion(log_probs, gen_label[batch_idx].to(device))
        optimizer_cls.zero_grad()
        loss.backward()
        optimizer_cls.step()


In [None]:
# Validation: top-1 accuracy of the generated-feature classifier on the real Task 1 validation features.
corr = 0
with torch.no_grad():
    for start in range(0, valid_feas.size(0), mb_size):
        x_mb = valid_feas[start:start + mb_size].float().to(device)
        y_mb = valid_label[start:start + mb_size].to(device)
        pred = torch.argmax(linear_cls(x_mb), dim=1)
        # .item() keeps `corr` a plain Python int, so the accuracy below prints as a number
        # rather than as a (possibly CUDA) tensor repr.
        corr += torch.eq(y_mb, pred).sum().item()

acc = 100 * corr / float(valid_feas.size(0))
print(f"accuracy: {acc}")

# Visualization

In [None]:
from sklearn.manifold import TSNE

# 2-D t-SNE embedding of the per-example latent codes in train_z.
# t-SNE is stochastic; a fixed random_state makes the embedding (and the plot) reproducible.
latent_rep = train_z.cpu().numpy()
embed = TSNE(n_components=2, random_state=0).fit_transform(latent_rep)


In [None]:
# Scatter the t-SNE embedding of the latent codes, one colour per class label in `train_label`.
labels_np = train_label.numpy()
mask0 = labels_np == 0  # kept: a later cell inspects this mask
mask1 = labels_np == 1

fig, ax = plt.subplots(figsize=(10, 7))
for cls_id in np.unique(labels_np):
    cls_mask = labels_np == cls_id
    ax.scatter(embed[cls_mask, 0], embed[cls_mask, 1], label=f"class {cls_id}")
ax.set(title="t-SNE of training latent codes", xlabel="t-SNE dim 1", ylabel="t-SNE dim 2")
ax.legend()
plt.show()


In [None]:
# NOTE(review): presumably a leftover inspection of the class-0 boolean mask from the plotting
# cell above; it only displays the array.
mask0