In [None]:
from sklearn.ensemble import RandomForestClassifier
import numpy as np
from sklearn.metrics import classification_report
import os
from natsort import natsorted
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim

# see https://github.com/yatindandi/Disentangled-Sequential-Autoencoder/blob/master/classifier.py

In [None]:
# set filename and parameters
root = './codes/release_test/'  # directory whose files are loaded in the next cell
nc = 3           # image channels (used when reshaping image arrays)
seq_len = 8      # frames per sequence
imsize = 64      # frame height/width in pixels
test_frac = 0.3  # fraction of samples held out for testing
seed = 0         # seed for the random train/test split, weight init and shuffling
# Seed the global RNGs so the split (np.random) and model init / DataLoader
# shuffling (torch) are reproducible across Restart & Run All.
np.random.seed(seed)
torch.manual_seed(seed)
files = natsorted(os.listdir(root))

In [None]:
# load and preprocess the data

# Sort the files in `root` into groups by substring. The checks run in order,
# and the 'label' guards keep label files out of the code lists.
id_codes_files = []
action_codes_files = []
labels_files = []
recon_files = []
test_img_files = []

for file in files:
    if 'dynamics' in file and 'label' not in file:
        action_codes_files.append(file)
    elif 'id' in file and 'label' not in file:
        id_codes_files.append(file)
    elif 'labels' in file:
        labels_files.append(file)
    elif 'recon' in file:
        recon_files.append(file)
    elif 'test_images' in file:
        test_img_files.append(file)


def _load_arr0(file_names, allow_pickle=False):
    """Return the 'arr_0' entry of each archive in `file_names`, read from `root`.

    `files` lists `root` itself, so each path is `root/<name>`. `allow_pickle`
    is passed to `np.load` per call rather than monkey-patching `np.load`
    globally (a patch that would stay active if a load raised).
    NOTE(review): allow_pickle=True can execute code embedded in a malicious
    file -- only use it on trusted, locally produced outputs.
    """
    arrays = []
    for name in file_names:
        with np.load(os.path.join(root, name), allow_pickle=allow_pickle) as archive:
            arrays.append(archive['arr_0'])
    return arrays


id_codes = np.asarray(_load_arr0(id_codes_files))
action_codes = np.asarray(_load_arr0(action_codes_files))
# Label entries are 0-d object arrays wrapping dicts, which requires pickle.
label_batches = _load_arr0(labels_files, allow_pickle=True)
test_imgs = np.asarray(_load_arr0(test_img_files, allow_pickle=True))
recon_gen = np.asarray(_load_arr0(recon_files, allow_pickle=True))

print(id_codes.shape)
id_codes = id_codes.reshape(-1, id_codes.shape[-1])
print(id_codes.shape)
print(action_codes.shape)
action_codes = action_codes.reshape(-1, action_codes.shape[-1])
print(action_codes.shape)
print(recon_gen.shape)
recon_gen = recon_gen.reshape(recon_gen.shape[0] * recon_gen.shape[1], seq_len, nc, imsize, imsize)
print(recon_gen.shape)
test_imgs = test_imgs.reshape(test_imgs.shape[0] * test_imgs.shape[1], seq_len, nc, imsize, imsize)

# Flatten the per-file label dicts into one (n_samples, 5) array whose columns
# are body, shirt, pant, hair, action -- the order the classifiers index.
label_keys = ('body', 'shirt', 'pant', 'hair', 'action')
label_columns = {key: [] for key in label_keys}
for batch in label_batches:
    batch_labels = batch.item()
    # Use each file's own sample count, not the first file's.
    num_in_batch = len(batch_labels['body'])
    for key in label_keys:
        label_columns[key].extend(batch_labels[key][j] for j in range(num_in_batch))
labels = np.concatenate([np.asarray(label_columns[key]).reshape(-1, 1) for key in label_keys], 1)

# Every array is indexed with the same train/test indices below, so they must
# have one row per label row.
n_samples = len(labels)
for array_name, array in (('id_codes', id_codes), ('action_codes', action_codes),
                          ('test_imgs', test_imgs), ('recon_gen', recon_gen)):
    assert len(array) == n_samples, '{} has {} rows but there are {} label rows'.format(
        array_name, len(array), n_samples)

# A permutation samples without replacement, so the train and test index sets
# are disjoint and together cover every sample (no train/test leakage).
index = np.random.permutation(n_samples)
n_test = int(n_samples * test_frac)
test_indices = index[:n_test]
train_indices = index[n_test:]

real_imgs_tr = test_imgs[train_indices]
real_imgs_te = test_imgs[test_indices]
gen_imgs_tr = recon_gen[train_indices]
gen_imgs_te = recon_gen[test_indices]
labels_tr = labels[train_indices]
labels_te = labels[test_indices]

action_codes_tr = action_codes[train_indices]
action_codes_te = action_codes[test_indices]
id_codes_tr = id_codes[train_indices]
id_codes_te = id_codes[test_indices]

# Real and reconstructed sequences stacked together share the same labels.
real_fake_train = np.concatenate((real_imgs_tr, gen_imgs_tr), 0)
real_fake_test = np.concatenate((real_imgs_te, gen_imgs_te), 0)
labels_both_train = np.concatenate((labels_tr, labels_tr), 0)
labels_both_te = np.concatenate((labels_te, labels_te), 0)

print(real_imgs_tr.shape, real_imgs_te.shape)
print(action_codes_tr.shape, action_codes_te.shape)
print(id_codes_tr.shape, id_codes_te.shape)

In [None]:
# define datasets
class Sprites(data.Dataset):
    """Map-style dataset pairing image sequences with their label rows.

    `imgs` and `labels` are indexed in parallel; the dataset length is the
    length of `imgs`. Items are returned as an `(imgs[idx], labels[idx])` tuple.
    """

    def __init__(self, imgs, labels):
        self.imgs, self.labels = imgs, labels

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return self.imgs[idx], self.labels[idx]

class Codes(data.Dataset):
    """Map-style dataset of (latent code, label row) pairs.

    `codes` and `labels` are indexed in parallel; the dataset length is the
    length of `codes`.
    """

    def __init__(self, codes, labels):
        self.codes, self.labels = codes, labels

    def __len__(self):
        return len(self.codes)

    def __getitem__(self, idx):
        code_row = self.codes[idx]
        label_row = self.labels[idx]
        return code_row, label_row

In [None]:
# define classifier model and metrics

class SpriteClassifier(nn.Module):
    """Multi-head attribute classifier for sequences of Sprites frames.

    Every frame goes through a stack of stride-4 convolutions and a linear
    projection to a `code_dim` vector. The per-frame vectors are run through a
    single-layer, unidirectional (plain, not convolutional) LSTM, and its final
    hidden state feeds five independent MLP heads: body, shirt, pants,
    hairstyle and action.

    Args:
        n_bodies, n_shirts, n_pants, n_hairstyles, n_actions: classes per head.
        num_frames: frames per sequence. forward() regroups frames into
            sequences of this length, so it must equal the input's dim 1.
        in_size: spatial size of each square frame. NOTE(review): the final
            feature-map size is tracked with `// 4` per stage, which matches
            the conv arithmetic for 64 but is not checked for other sizes.
        channels: output channels of the first conv, doubled at each later stage.
        code_dim: per-frame feature size fed to the LSTM.
        hidden_dim: LSTM hidden size; each head's hidden layer is hidden_dim // 2.
        nonlinearity: activation module instance shared by every layer
            (LeakyReLU(0.2) by default). A module with parameters would
            therefore share them across all layers.

    forward(x) takes x shaped (batch, num_frames, 3, in_size, in_size) and
    returns a 5-tuple of logits, each shaped (batch, n_classes_of_that_head).
    """

    def __init__(self, n_bodies=7, n_shirts=4, n_pants=5, n_hairstyles=6, n_actions=9,
                 num_frames=8, in_size=64, channels=64, code_dim=1024, hidden_dim=512, nonlinearity=None):
        super(SpriteClassifier, self).__init__()
        act = nonlinearity if nonlinearity is not None else nn.LeakyReLU(0.2)
        self.num_frames = num_frames

        # First stage has no BatchNorm; each later stage doubles the channels
        # and divides the spatial size by 4 until it is at most 4.
        stages = [nn.Sequential(nn.Conv2d(3, channels, 5, 4, 1, bias=False), act)]
        size = in_size // 4
        while size > 4:
            stages.append(nn.Sequential(
                nn.Conv2d(channels, 2 * channels, 5, 4, 1, bias=False),
                nn.BatchNorm2d(2 * channels), act))
            channels, size = 2 * channels, size // 4
        self.encoding_conv = nn.Sequential(*stages)

        self.final_size = size
        self.final_channels = channels
        self.code_dim = code_dim
        self.hidden_dim = hidden_dim
        self.encoding_fc = nn.Sequential(
                nn.Linear(size * size * channels, code_dim),
                nn.BatchNorm1d(code_dim), act)
        # The last hidden state of this LSTM over the frames is what the heads classify.
        self.classifier_lstm = nn.LSTM(code_dim, hidden_dim, batch_first=True, bidirectional=False)

        half = hidden_dim // 2

        def make_head(n_classes):
            # Linear -> BatchNorm -> activation -> Linear (state_dict indices 0..3).
            return nn.Sequential(
                nn.Linear(hidden_dim, half),
                nn.BatchNorm1d(half), act,
                nn.Linear(half, n_classes))

        # Registration order is kept so parameter init and state_dict keys match.
        self.body = make_head(n_bodies)
        self.shirt = make_head(n_shirts)
        self.pants = make_head(n_pants)
        self.hairstyles = make_head(n_hairstyles)
        self.action = make_head(n_actions)

    def forward(self, x):
        # Fold the frame axis into the batch axis so each frame is convolved independently.
        frames = x.view(-1, x.size(2), x.size(3), x.size(4))
        flat = self.encoding_conv(frames).view(-1, self.final_channels * (self.final_size ** 2))
        # Unfold back to (batch, num_frames, code_dim) for the LSTM.
        per_frame = self.encoding_fc(flat).view(-1, self.num_frames, self.code_dim)
        # h_n is (num_layers * num_directions, batch, hidden_dim) = (1, batch, hidden_dim) here.
        _, (h_n, _) = self.classifier_lstm(per_frame)
        summary = h_n.view(-1, self.hidden_dim)
        heads = (self.body, self.shirt, self.pants, self.hairstyles, self.action)
        return tuple(head(summary) for head in heads)

def check_accuracy(model, test_data, device):
    """Print and return the per-attribute accuracy of `model` over `test_data`.

    Args:
        model: module whose forward returns five logit tensors, in the order
            (body, shirt, pant, hair, action), each shaped (batch, n_classes).
        test_data: iterable of (inputs, labels) batches. Label columns 0-4 are
            the body/shirt/pant/hair/action class ids.
        device: device that inputs and labels are moved to.

    Returns:
        dict mapping 'body', 'shirt', 'pant', 'hair', 'action' to accuracy in [0, 1].

    Raises:
        ValueError: if `test_data` yields no samples.

    The model is evaluated in eval mode, so BatchNorm uses its running statistics
    and is not updated from test batches. It is then restored to its previous mode.
    """
    attr_names = ('body', 'shirt', 'pant', 'hair', 'action')
    correct = dict.fromkeys(attr_names, 0)
    total = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for image, label in test_data:
                image = image.to(device)
                label = label.to(device)
                preds = model(image)
                for col, (name, logits) in enumerate(zip(attr_names, preds)):
                    correct[name] += (logits.argmax(1) == label[:, col]).sum().item()
                total += label.size(0)
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)
    if total == 0:
        raise ValueError('check_accuracy: test_data yielded no samples')
    acc = {name: correct[name] / total for name in attr_names}
    print('Accuracy, Body : {} Shirt : {} Pant : {} Hair : {} Action {}'.format(acc['body'], acc['shirt'], acc['pant'], acc['hair'], acc['action']))
    return acc


def train_classifier(model, optim, train_data, device, epochs, path, test_data, start=0):
    """Train `model` on the summed cross-entropy of its five heads.

    Args:
        model: module whose forward returns five logit tensors, in the order
            (body, shirt, pant, hair, action).
        optim: optimizer over `model`'s parameters.
        train_data: iterable of (inputs, labels) batches. Label columns 0-4 are
            the body/shirt/pant/hair/action class ids.
        device: device that inputs and labels are moved to.
        epochs: index one past the last epoch to run (epochs `start`..`epochs-1`).
        path: checkpoint path. Currently unused because checkpointing is disabled.
        test_data: passed to `check_accuracy` after every epoch.
        start: first epoch index (for resuming).
    """
    criterion = nn.CrossEntropyLoss()
    for epoch in range(start, epochs):
        # Re-enter train mode each epoch in case evaluation left the model in eval mode.
        model.train()
        running_loss = 0.0
        n_batches = 0
        for image, label in train_data:
            image = image.to(device)
            # CrossEntropyLoss needs int64 class targets.
            label = label.to(device).long()
            preds = model(image)
            loss = sum(criterion(logits, label[:, col]) for col, logits in enumerate(preds))
            # Clear the previous step's gradients; otherwise they accumulate across steps.
            optim.zero_grad()
            loss.backward()
            optim.step()
            running_loss += loss.item()
            n_batches += 1
        avg_loss = running_loss / n_batches if n_batches else float('nan')
        print('Epoch {} Avg Loss {}'.format(epoch + 1, avg_loss))
#         save_model(model, optim, epoch, path)
        check_accuracy(model, test_data, device)
     

In [None]:
# init model and train
# Fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = SpriteClassifier()
model.to(device)
optim = torch.optim.Adam(model.parameters(), lr=0.0004)

# Train on real + reconstructed sequences and evaluate on held-out reconstructions only.
# Alternatives: test on Sprites(real_fake_test, labels_both_te), or train on
# Sprites(real_imgs_tr, labels_tr) (real sequences only).
sprites_train = Sprites(real_fake_train, labels_both_train)
sprites_test = Sprites(gen_imgs_te, labels_te)
loader = data.DataLoader(sprites_train, batch_size=32, shuffle=True, num_workers=4)
# Sample order does not affect accuracy, so the test set is not shuffled.
loader_test = data.DataLoader(sprites_test, batch_size=64, shuffle=False, num_workers=4)
train_classifier(model, optim, loader, device, 50, './checkpoint_classifier.pth', loader_test)

In [None]:

# compare real image preds against fake image preds
# For each held-out pair (real sequence, its reconstruction), count how often the
# trained classifier predicts the same class for both. The printed numbers are
# real/fake agreement rates, not accuracy against the ground-truth labels.
bs = 2
real_test = Sprites(real_imgs_te, labels_te)
fake_test = Sprites(gen_imgs_te, labels_te)
# shuffle=False on both loaders keeps the zipped real/fake batches aligned.
real_loader = data.DataLoader(real_test, batch_size=bs, shuffle=False, num_workers=4)
fake_loader = data.DataLoader(fake_test, batch_size=bs, shuffle=False, num_workers=4)

total = 0
agree = [0, 0, 0, 0, 0]  # body, shirt, pant, hair, action
model.eval()
# Inference only: no_grad avoids building autograd graphs for every batch.
with torch.no_grad():
    for i, ((image, _), (fake_image, _)) in enumerate(zip(real_loader, fake_loader)):
        if i % 10 == 0:
            print(i)
        image = image.to(device)
        fake_image = fake_image.to(device)
        real_preds = [logits.argmax(1) for logits in model(image)]
        fake_preds = [logits.argmax(1) for logits in model(fake_image)]
        total += real_preds[0].size(0)
        for k, (real_pred, fake_pred) in enumerate(zip(real_preds, fake_preds)):
            agree[k] += (real_pred == fake_pred).sum().item()
correct_body, correct_shirt, correct_pant, correct_hair, correct_action = agree
print('Accuracy, Body : {} Shirt : {} Pant : {} Hair : {} Action {}'.format(correct_body/total, correct_shirt/total, correct_pant/total, correct_hair/total, correct_action/total))


    
    

In [None]:
# now test disentanglement using the id vs action codes
    
class CodeClassifier(nn.Module):
    """MLP probe predicting the five Sprites attributes from a flat latent code.

    A shared Linear -> BatchNorm1d -> activation trunk maps the `in_size`-dim
    input to `hidden_dim` features. Five independent heads follow (body, shirt,
    pants, hairstyle, action), each with a hidden layer of size hidden_dim // 2.
    `code_dim` is only stored on the instance; it does not affect any layer.
    The activation module instance is shared by every layer (LeakyReLU(0.2)
    by default).

    forward(x) takes x shaped (batch, in_size) and returns a 5-tuple of logits.
    """

    def __init__(self, n_bodies=7, n_shirts=4, n_pants=5, n_hairstyles=6, n_actions=9,
                  in_size=15, code_dim=15, hidden_dim=256, nonlinearity=None):
        super(CodeClassifier, self).__init__()
        act = nonlinearity if nonlinearity is not None else nn.LeakyReLU(0.2)
        half = hidden_dim // 2

        self.code_dim = code_dim
        self.hidden_dim = hidden_dim
        self.encoding_fc = nn.Sequential(
                nn.Linear(in_size, hidden_dim),
                nn.BatchNorm1d(hidden_dim), act)

        def make_head(n_classes):
            # Linear -> BatchNorm -> activation -> Linear (state_dict indices 0..3).
            return nn.Sequential(
                nn.Linear(hidden_dim, half),
                nn.BatchNorm1d(half), act,
                nn.Linear(half, n_classes))

        # Registration order is kept so parameter init and state_dict keys match.
        self.body = make_head(n_bodies)
        self.shirt = make_head(n_shirts)
        self.pants = make_head(n_pants)
        self.hairstyles = make_head(n_hairstyles)
        self.action = make_head(n_actions)

    def forward(self, x):
        features = self.encoding_fc(x)
        heads = (self.body, self.shirt, self.pants, self.hairstyles, self.action)
        return tuple(head(features) for head in heads)
    
       
    
# Fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def _train_code_probe(codes_tr, codes_te, labels_tr, labels_te, lr, epochs):
    """Train a fresh CodeClassifier on latent codes, printing test accuracy each epoch.

    The probe's input size is read from the codes' last dimension rather than
    hard-coded. Returns the trained probe.
    """
    probe = CodeClassifier(in_size=codes_tr.shape[-1])
    probe.to(device)
    probe_optim = torch.optim.Adam(probe.parameters(), lr=lr)
    # drop_last avoids a final single-sample batch, which BatchNorm1d rejects in train mode.
    train_loader = data.DataLoader(Codes(codes_tr, labels_tr), batch_size=32, shuffle=True,
                                   num_workers=4, drop_last=True)
    test_loader = data.DataLoader(Codes(codes_te, labels_te), batch_size=32, shuffle=False,
                                  num_workers=4)
    train_classifier(probe, probe_optim, train_loader, device, epochs, './checkpoint_classifier.pth', test_loader)
    return probe


# Probe each code type for all five attributes. If the codes are disentangled,
# one would presumably expect the id code to predict body/shirt/pant/hair but
# not action, and the action code the reverse (interpretation, not checked here).
id_probe = _train_code_probe(id_codes_tr, id_codes_te, labels_tr, labels_te, lr=0.0003, epochs=50)
action_probe = _train_code_probe(action_codes_tr, action_codes_te, labels_tr, labels_te, lr=0.0004, epochs=100)
model = action_probe  # as before, `model` ends up as the last-trained probe