In [2]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, cohen_kappa_score, f1_score, classification_report

In [5]:
class Net(nn.Module):
    """Fully connected binary classifier: INPUT_SIZE -> 500 -> 100 -> 1.

    Both hidden layers use ReLU. The single output unit goes through a sigmoid,
    so every prediction lies in (0, 1).

    The attribute names (fc1, fc11, fc2) become the state_dict keys. They must
    match those of any checkpoint passed to ``load_state_dict``.
    """

    def __init__(self, INPUT_SIZE):
        super().__init__()
        # Per the original author, the input is the output of a resnet18
        # image classification model; INPUT_SIZE is its feature width.
        self.fc1 = nn.Linear(INPUT_SIZE, 500)
        self.fc11 = nn.Linear(500, 100)
        self.fc2 = nn.Linear(100, 1)
        self.act_out = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid scores with a trailing dimension of size 1 for input ``x``.

        ``x``'s last dimension must equal INPUT_SIZE (required by ``fc1``).
        """
        first_hidden = F.relu(self.fc1(x))
        second_hidden = F.relu(self.fc11(first_hidden))
        logit = self.fc2(second_hidden)
        return self.act_out(logit)


def _read_id_table(path, columns):
    """Read a JSON object keyed by id into a DataFrame whose index is named 'id'."""
    with open(path) as f:
        raw = json.load(f)
    # orient='index' turns each top-level JSON key into one row.
    table = pd.DataFrame.from_dict(raw, orient='index', columns=columns)
    table.index.name = 'id'
    return table


def load_dataset_df(label_path, data_dir='./dataset_work'):
    """Load a labels CSV and left-join image embeddings and personality traits onto it.

    Parameters
    ----------
    label_path : str
        CSV file with an ``id`` column. An ``Unnamed: 0`` column, which
        ``DataFrame.to_csv`` writes when ``index=True``, is dropped if present.
    data_dir : str, default './dataset_work'
        Directory containing ``embeddings.json`` and ``personalities.json``.
        Both are JSON objects keyed by id.

    Returns
    -------
    pandas.DataFrame
        The CSV rows plus an ``embedding`` column and six personality columns.
        Both joins are left joins, so ids absent from a JSON file get NaN there.
    """
    personality_cols = ['playfulness', 'chase-proneness', 'curiosity',
                        'sociability', 'aggressiveness', 'shyness']

    # JSON object keys are always strings. Read ids as str so that numeric ids
    # in the CSV still match (otherwise pandas refuses an int64/object merge).
    labels = pd.read_csv(label_path, dtype={'id': str})
    labels = labels.drop(columns=['Unnamed: 0'], errors='ignore')

    # `on='id'` pairs the CSV's id column with each table's index level 'id'.
    embeddings = _read_id_table(os.path.join(data_dir, 'embeddings.json'), ['embedding'])
    dataset = labels.merge(embeddings, on='id', how='left')

    personalities = _read_id_table(os.path.join(data_dir, 'personalities.json'), personality_cols)
    dataset = dataset.merge(personalities, on='id', how='left')

    return dataset

def vectorize_dataset(df, with_personalities, with_embeddings):
    """Turn a merged dataset frame into (features, labels) float32 tensors.

    Feature columns, concatenated left to right:
      * with_embeddings and with_personalities -> embedding values, then the 6 personality traits
      * with_embeddings only                   -> embedding values
      * embeddings not requested               -> the 6 personality traits

    NOTE(review): with_personalities=False, with_embeddings=False still returns
    personality features. This preserves the original fallback branch.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a ``label`` column. It also needs an ``embedding`` column
        (each cell a numeric sequence, all of equal length) when embeddings
        are used, and the six personality columns when personalities are used.
    with_personalities, with_embeddings : bool
        Which feature groups to include (see above).

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        Features of shape (n_rows, n_features) and labels of shape (n_rows, 1),
        both float32.

    Raises
    ------
    ValueError
        If embeddings are requested and some rows have no embedding.
    """
    personality_cols = ['playfulness', 'chase-proneness', 'curiosity',
                        'sociability', 'aggressiveness', 'shyness']
    blocks = []
    if with_embeddings:
        embeddings = df['embedding']
        # Left joins (e.g. in load_dataset_df) leave NaN for ids without an
        # embedding. Fail with a clear message rather than a numpy shape error.
        n_missing = int(embeddings.isna().sum())
        if n_missing:
            raise ValueError(f"{n_missing} row(s) have no embedding")
        # Each cell holds one embedding sequence; stack them into a 2-D array.
        blocks.append(np.array(embeddings.tolist(), dtype='float64'))
    # Personalities are used when requested, and also as the fallback when
    # embeddings are not requested.
    if with_personalities or not with_embeddings:
        blocks.append(df[personality_cols].to_numpy(dtype='float64'))
    x = np.concatenate(blocks, axis=1)
    y = df['label'].to_numpy()
    return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32).reshape(-1, 1)

def load(label_path, with_personalities, with_embeddings):
    """Build the (features, labels) tensors for the rows listed in ``label_path``.

    Convenience wrapper: the frame from ``load_dataset_df`` is handed straight
    to ``vectorize_dataset`` with the requested feature groups.
    """
    merged = load_dataset_df(label_path)
    return vectorize_dataset(
        merged,
        with_personalities=with_personalities,
        with_embeddings=with_embeddings,
    )

def pred(label_path, model_path=None):
    """Predict 0/1 labels for every row of ``label_path`` with a saved ``Net`` checkpoint.

    Parameters
    ----------
    label_path : str
        Labels CSV, passed to ``load``. Features are built from embeddings only
        (``with_personalities=False, with_embeddings=True``); the CSV's labels
        are loaded but not used here.
    model_path : str, optional
        Path of a ``Net`` state_dict. Defaults to the checkpoint this notebook
        was run with.

    Returns
    -------
    torch.Tensor
        Shape (n_rows, 1). The sigmoid scores rounded with ``torch.round``
        (so 0. or 1.), computed without autograd tracking.
    """
    if model_path is None:
        # The filename suggests an images-only model (images_True,
        # personalities_False), consistent with the features built below.
        model_path = "trained_models/alice_medium_images_True_personalities_False_optm_adam_loss_MSE_EPOCHS_150_BATCH_1000_REG_0.001_LR_0.0001_f1_0.7230682829943155"
    features, _ = load(label_path, with_personalities=False, with_embeddings=True)
    model = Net(features.shape[1])
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host.
    # The model is built on the CPU anyway, so the weights end up there either way.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        return model(features).round()

In [6]:
# Run the saved embedding-only model on the ids listed in this labels file.
# NOTE(review): this displays every prediction; a summary may read better.
labels_csv = 'dataset_work/labels/personality_only/alice_train_personalityTrue_imageFalse_labels.csv'
pred(labels_csv)

tensor([[1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      