In [13]:
from sklearn.ensemble import RandomForestClassifier
import numpy as np
from sklearn.metrics import classification_report
import os
from natsort import natsorted
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim
# this script is to be used after 
# see https://github.com/yatindandi/Disentangled-Sequential-Autoencoder/blob/master/classifier.py

In [14]:
# set filenames and parameters etc.
root = './codes/release_test/'
nc = 3
seq_len = 20
imsize = 64
test_frac = 0.3
files = natsorted(os.listdir(root))

In [15]:
# import data and create train test splits etc.
id_codes_files = []
action_codes_files = []
labels_files = []
recon_files = []
test_img_files = []
sampled_img_files = []
for file in files:
    if 'dynamics' in file and 'label' not in file:
        action_codes_files.append(file)
    elif 'id' in file and 'label' not in file:
        id_codes_files.append(file)
    elif 'labels' in file:
        labels_files.append(file)
    elif 'recon' in file:
        recon_files.append(file)
    elif 'test_images' in file:
        test_img_files.append(file)
    elif 'sampled_images' in file:
        sampled_img_files.append(file)
         
id_codes = []
action_codes = []
labels = []
recon_gen = []
test_imgs = []
sampled_imgs = []

for file in id_codes_files:
    id_codes.append(np.load(os.path.join(root, model, file))['arr_0'])
for file in action_codes_files:
    action_codes.append(np.load(os.path.join(root, model, file))['arr_0'])
np_load_old = np.load
np.load = lambda *a,**k: np_load_old(*a, allow_pickle=True, **k)
for file in labels_files:
    labels.append(np.load(os.path.join(root, model, file))['arr_0'])
for file in test_img_files:
    test_imgs.append(np.load(os.path.join(root,model,file))['arr_0'])
for file in recon_files:
    recon_gen.append(np.load(os.path.join(root,model,file))['arr_0'])
for file in sampled_img_files:
    sampled_imgs.append(np.load(os.path.join(root,model,file))['arr_0'])
np.load = np_load_old
    
id_codes = np.asarray(id_codes)
print('id', id_codes.shape, id_codes[1,0,0,0])
id_codes = id_codes.reshape(-1, id_codes.shape[-1])
print(id_codes.shape, id_codes[10,0])
action_codes = np.asarray(action_codes)
print('act', action_codes.shape, action_codes[1,0,0])
action_codes = action_codes.reshape(-1, action_codes.shape[-1])
print(action_codes.shape, action_codes[6,0])
recon_gen = np.asarray(recon_gen)
print('recon', recon_gen.shape, recon_gen[1,0,0,0,0,0])
recon_gen = recon_gen.reshape(recon_gen.shape[0]*recon_gen.shape[1], seq_len, nc, imsize, imsize)        
print(recon_gen.shape, recon_gen[10, 0,0,0,0])
test_imgs = np.asarray(test_imgs)
test_imgs = test_imgs.reshape(test_imgs.shape[0]*test_imgs.shape[1], seq_len, nc, imsize, imsize)
sampled_imgs = np.asarray(sampled_imgs)
sampled_imgs = sampled_imgs.reshape(sampled_imgs.shape[0]*sampled_imgs.shape[1], seq_len, nc, imsize, imsize)

ids = []
actions = []
for i in range(len(labels)):
    num_in_batch = len(labels[0])
    
    for j in range(num_in_batch):
        ids.append(labels[i][j][0])
        actions.append(labels[i][j][1])
        
        
ids = np.asarray(ids).reshape(-1, 1)
actions = np.asarray(actions).reshape(-1, 1)
labels = np.concatenate((ids, actions), 1)

index = np.random.randint(len(labels), size=len(labels))
test_indices = index[:int(len(labels)*test_frac)]
train_indices = index[int(len(labels)*test_frac):]

real_imgs_tr = test_imgs[train_indices]
real_imgs_te = test_imgs[test_indices]
gen_imgs_tr = recon_gen[train_indices]
gen_imgs_te = recon_gen[test_indices]
labels_tr = labels[train_indices]
labels_te = labels[test_indices]

action_codes_tr = action_codes[train_indices]
action_codes_te = action_codes[test_indices]
id_codes_tr = id_codes[train_indices]
id_codes_te = id_codes[test_indices]



print(real_imgs_tr.shape, real_imgs_te.shape)
print(action_codes_tr.shape, action_codes_te.shape)
print(id_codes_tr.shape, id_codes_te.shape)
real_fake_train = np.concatenate((real_imgs_tr, gen_imgs_tr),0)
real_fake_test = np.concatenate((real_imgs_te, gen_imgs_te),0)
labels_both_train = np.concatenate((labels_tr, labels_tr), 0)
labels_both_te = np.concatenate((labels_te, labels_te), 0)

id (200, 10, 1, 15) 0.022323437
(2000, 15) 0.022323437
act (200, 10, 50) -0.0013404469
(2000, 50) -0.0014967673
recon (200, 10, 20, 3, 64, 64) 1.2447744e-05
(2000, 20, 3, 64, 64) 1.2447744e-05


In [18]:
# Dataset objects

class MUG(data.Dataset):
    def __init__(self, imgs, labels):
        self.imgs = imgs
        self.labels = labels
        
    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        imgs = self.imgs[idx]
        labels = self.labels[idx]
        return imgs, labels

class Codes(data.Dataset):
    def __init__(self, codes, labels):
        self.codes = codes
        self.labels = labels
        
    def __len__(self):
        return len(self.codes)

    def __getitem__(self, idx):
        codes = self.codes[idx]
        labels = self.labels[idx]
        return codes, labels

(1400, 20, 3, 64, 64) (600, 20, 3, 64, 64)
(1400, 50) (600, 50)
(1400, 15) (600, 15)


In [20]:

# define classifier model and metrics
class MUGClassifier(nn.Module):
    def __init__(self, n_ids=52, n_actions=9,
                 num_frames=seq_len, in_size=64, channels=64, code_dim=1024, hidden_dim=512, nonlinearity=None):
        super(MUGClassifier, self).__init__()
        nl = nn.LeakyReLU(0.2) if nonlinearity is None else nonlinearity
        encoding_conv = []
        encoding_conv.append(nn.Sequential(nn.Conv2d(3, channels, 5, 4, 1, bias=False), nl))
        size = in_size // 4
        self.num_frames = num_frames
        while size > 4:
            encoding_conv.append(nn.Sequential(
                nn.Conv2d(channels, channels * 2, 5, 4, 1, bias=False),
                nn.BatchNorm2d(channels * 2), nl))
            size = size // 4
            channels *= 2
        self.encoding_conv = nn.Sequential(*encoding_conv)
        self.final_size = size
        self.final_channels = channels
        self.code_dim = code_dim
        self.hidden_dim = hidden_dim
        self.encoding_fc = nn.Sequential(
                nn.Linear(size * size * channels, code_dim),
                nn.BatchNorm1d(code_dim), nl)
        # The last hidden state of a convolutional LSTM over the scenes is used for classification
        self.classifier_lstm = nn.LSTM(code_dim, hidden_dim, batch_first=True, bidirectional=False)
        self.id = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.BatchNorm1d(hidden_dim // 2), nl,
                nn.Linear(hidden_dim // 2, n_ids))
        self.action = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.BatchNorm1d(hidden_dim // 2), nl,
                nn.Linear(hidden_dim // 2, n_actions))

    def forward(self, x):
        x = x.view(-1, x.size(2), x.size(3), x.size(4))
        x = self.encoding_conv(x)
        x = x.view(-1, self.final_channels * (self.final_size ** 2))
        x = self.encoding_fc(x)
        x = x.view(-1, self.num_frames, self.code_dim)
        # Classifier output depends on last layer of LSTM: Can also change this to a bi-LSTM if required
        _, (hidden, _) = self.classifier_lstm(x)
        hidden = hidden.view(-1, self.hidden_dim)
        return self.id(hidden), self.action(hidden)

def check_accuracy(model, test_data, device):
    total = 0
    correct_id = 0
    correct_action = 0
    with torch.no_grad():
        for item in test_data:
            image, label = item
            image = image.to(device)
            id_ = label[:, 0].to(device)
            action = label[:, 1].to(device)
            pred_id, pred_action = model(image)
            _, pred_id = torch.max(pred_id.data, 1)
            _, pred_action = torch.max(pred_action.data, 1)
            total += id_.size(0)
            correct_id += (pred_id == id_).sum().item()
            correct_action += (pred_action == action).sum().item()
    print('Accuracy, id : {}  Action {}'.format(correct_id/total, correct_action/total)) 


def train_classifier(model, optim, train_data, device, epochs, path, test_data, start=0):
    model.train()
    criterion = nn.CrossEntropyLoss()
    for epoch in range(start, epochs):
        running_loss = 0.0
        for i, item in enumerate(train_data, 1):
            image, label = item
            image = image.to(device)
            id_ = label[:, 0].to(device)
            action = label[:, 1].to(device)
            pred_id, pred_action = model(image)
            loss = criterion(pred_id, id_) + criterion(pred_action, action)
            loss.backward()
            optim.step()
            running_loss += loss.item()
        print('Epoch {} Avg Loss {}'.format(epoch + 1, running_loss / i))
#         save_model(model, optim, epoch, path)
        check_accuracy(model, test_data, device)
     

In [21]:
# init model and train
device = torch.device('cuda:0')
model = MUGClassifier()
model.to(device)
optim = torch.optim.Adam(model.parameters(), lr=0.0003)

mug_train = MUG(real_fake_train, labels_both_train)
# sprites_test = Sprites(real_fake_test, labels_both_te)
# sprites_train = Sprites(real_imgs_tr, labels_tr)
mug_test = MUG(gen_imgs_te, labels_te)
loader = data.DataLoader(mug_train, batch_size=32, shuffle=True, num_workers=4)
loader_test = data.DataLoader(mug_test, batch_size=64, shuffle=True, num_workers=4)
train_classifier(model, optim, loader, device, 50, './checkpoint_classifier.pth', loader_test) 

Epoch 1 Avg Loss 2.3598535643382506
Accuracy, id : 0.895  Action 0.6233333333333333
Epoch 2 Avg Loss 1.1181245690042323
Accuracy, id : 0.9566666666666667  Action 0.765
Epoch 3 Avg Loss 0.7033617293292825
Accuracy, id : 0.98  Action 0.7966666666666666
Epoch 4 Avg Loss 0.5556173095987602
Accuracy, id : 0.9633333333333334  Action 0.8233333333333334
Epoch 5 Avg Loss 0.48750003977594053
Accuracy, id : 0.9716666666666667  Action 0.8516666666666667
Epoch 6 Avg Loss 0.4349214125593955
Accuracy, id : 0.995  Action 0.88
Epoch 7 Avg Loss 0.38797594217414205
Accuracy, id : 0.985  Action 0.8766666666666667
Epoch 8 Avg Loss 0.38943881168961525
Accuracy, id : 0.99  Action 0.9016666666666666
Epoch 9 Avg Loss 0.3449688679327003
Accuracy, id : 0.9933333333333333  Action 0.8816666666666667
Epoch 10 Avg Loss 0.3073172631927512
Accuracy, id : 0.9916666666666667  Action 0.8983333333333333
Epoch 11 Avg Loss 0.2966213844377886
Accuracy, id : 0.99  Action 0.9416666666666667
Epoch 12 Avg Loss 0.3104266249468889

In [None]:

# compare real image preds against fake image preds
bs = 2
real_test = MUG(real_imgs_te, labels_te)
fake_test = MUG(gen_imgs_te, labels_te)
real_loader = data.DataLoader(real_test, batch_size=bs, shuffle=False, num_workers=4)
fake_loader = data.DataLoader(fake_test, batch_size=bs, shuffle=False, num_workers=4)

total = 0
correct_id = 0
correct_action = 0
model.eval()
for i, (item_real, item_fake) in enumerate(zip(real_loader, fake_loader)):
    
    if i % 10 == 0:
        print(i)
    fake_image, _ = item_fake
    fake_image = fake_image.to(device)
    image, _ = item_real
    image = image.to(device)
    im_check = image[0,0].permute(1,2,0).detach().cpu().numpy()
    im_check_fake = fake_image[0,0].permute(1,2,0).detach().cpu().numpy()

    # real
    pred_id, pred_action = model(image)
    _, pred_id = torch.max(pred_id.data, 1)
    _, pred_action = torch.max(pred_action.data, 1)
    # fake
    fake_pred_id, fake_pred_action = model(fake_image)
    _, fake_pred_id = torch.max(fake_pred_id.data, 1)
    _, fake_pred_action = torch.max(fake_pred_action.data, 1)
    total += pred_id.size(0)
    correct_id += (pred_id == fake_pred_id).sum().item()
    correct_action += (pred_action == fake_pred_action).sum().item()
print('Accuracy, id : {} Action {}'.format(correct_id/total, correct_action/total)) 

  

In [73]:
# INTER/INTRA ENTROPY AND IS KL

dummy_labels = torch.zeros(len(sampled_imgs), 2)
# compare real image preds against fake image preds
bs = 1
sampled_test = MUG(sampled_imgs, dummy_labels)  # here labels_te is a dummy 
sampled_loader = data.DataLoader(sampled_test, batch_size=bs, shuffle=False, num_workers=4)

pred_ids = []
pred_actions = []
model.eval()
sm = nn.Softmax(dim=1)
for i, item_sampled in enumerate(sampled_loader):
    
    if i % 10 == 0:
        print(i)
    sampled_image, _ = item_sampled
    sampled_image = sampled_image.to(device)
    # real
    pred_id, pred_action = model(sampled_image)
    pred_ids.append(sm(pred_id).detach().cpu().numpy())
    pred_actions.append(sm(pred_action).detach().cpu().numpy())

pred_actions = np.asarray(pred_actions)[:,0]
pred_ids = np.asarray(pred_ids)[:,0]
eps = 1e-16

pred_actions_y = pred_actions.mean(axis=0)
inter_entropy_actions = -(pred_actions_y * np.log(pred_actions_y + eps)).sum(0)

pred_ids_y = pred_ids.mean(axis=0)
inter_entropy_ids = -(pred_ids_y * np.log(pred_ids_y + eps)).sum(0)

intra_entropy_actions = -(pred_actions * np.log(pred_actions + eps)).sum(1).mean()
intra_entropy_ids = -(pred_ids * np.log(pred_ids + eps)).sum(1).mean()

kl_ids = (pred_ids * (np.log(pred_ids + eps) - np.log(pred_ids_y + eps))).sum(1).mean()
kl_actions = (pred_actions * (np.log(pred_actions + eps) - np.log(pred_actions_y + eps))).sum(1).mean()

average_inter_entropy = (inter_entropy_ids + inter_entropy_actions)/2
average_intra_entropy = (intra_entropy_actions + intra_entropy_ids)/2
average_kl = (kl_ids + kl_actions)/2
print('av. inter_ent', average_inter_entropy, 'av intra ent', average_intra_entropy, 'av kl', average_kl)

0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990
1000
1010
1020
1030
1040
1050
1060
1070
1080
1090
1100
1110
1120
1130
1140
1150
1160
1170
1180
1190
1200
1210
1220
1230
1240
1250
1260
1270
1280
1290
1300
1310
1320
1330
1340
1350
1360
1370
1380
1390
1400
1410
1420
1430
1440
1450
1460
1470
1480
1490
1500
1510
1520
1530
1540
1550
1560
1570
1580
1590
1600
1610
1620
1630
1640
1650
1660
1670
1680
1690
1700
1710
1720
1730
1740
1750
1760
1770
1780
1790
1800
1810
1820
1830
1840
1850
1860
1870
1880
1890
1900
1910
1920
1930
1940
1950
1960
1970
1980
1990


In [None]:
# test disentanglement by comparing prediction using id and action codes

class CodeClassifier(nn.Module):
    def __init__(self, n_ids=52, n_actions=9,
                  in_size=15, code_dim=15, hidden_dim=256, nonlinearity=None):
        super(CodeClassifier, self).__init__()
        nl = nn.LeakyReLU(0.2) if nonlinearity is None else nonlinearity

        self.code_dim = code_dim
        self.hidden_dim = hidden_dim
        self.encoding_fc = nn.Sequential(
                nn.Linear(in_size, hidden_dim),
                nn.BatchNorm1d(hidden_dim), nl)
        
        self.id = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.BatchNorm1d(hidden_dim // 2), nl,
                nn.Linear(hidden_dim // 2, n_ids))
        self.action = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.BatchNorm1d(hidden_dim // 2), nl,
                nn.Linear(hidden_dim // 2, n_actions))

    def forward(self, x):
        x = self.encoding_fc(x)
        return self.id(x), self.action(x)
       
    
device = torch.device('cuda:0')
model = CodeClassifier(in_size=15)
model.to(device)
optim = torch.optim.Adam(model.parameters(), lr=0.0003)

mug_train = Codes(id_codes_tr, labels_tr)
# sprites_test = Sprites(real_fake_test, labels_both_te)
# sprites_train = Sprites(real_imgs_tr, labels_tr)
mug_test = Codes(id_codes_te, labels_te)
loader = data.DataLoader(mug_train, batch_size=32, shuffle=True, num_workers=4)
loader_test = data.DataLoader(mug_test, batch_size=32, shuffle=True, num_workers=4)
train_classifier(model, optim, loader, device, 50, './checkpoint_classifier.pth', loader_test) 

device = torch.device('cuda:0')
model = CodeClassifier(in_size=10)
model.to(device)
optim = torch.optim.Adam(model.parameters(), lr=0.0004)

mug_train = Codes(action_codes_tr, labels_tr)

mug_test = Codes(action_codes_te, labels_te)
loader = data.DataLoader(mug_train, batch_size=32, shuffle=True, num_workers=4)
loader_test = data.DataLoader(mug_test, batch_size=32, shuffle=True, num_workers=4)
train_classifier(model, optim, loader, device, 100, './checkpoint_classifier.pth', loader_test) 