In [1]:
from functions import *
from comet.csk_feature_extract import CSKFeatureExtractor

import numpy as np

2023-07-09 17:50:45 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


# Constants

In [2]:
pretrained_roberta_path = './checkpoints/iemocap/checkpoint_best.pt'
roberta_data_path = 'iemocap-bin'
model_path = 'best_model_iemocap.pt'

# Loading Models

In [3]:
# Loading COMET and RoBerta
path, file = os.path.split(pretrained_roberta_path)
roberta = RobertaModel.from_pretrained(
    path,
    checkpoint_file=file,
    data_name_or_path=roberta_data_path
)
roberta.eval()

comet_extractor = CSKFeatureExtractor()

# Loading the Model
D_m = 1024
D_s = 768
D_g = 150
D_p = 150
D_r = 150
D_i = 150
D_h = 100
D_a = 100
D_e = D_p + D_r + D_i

model = CommonsenseGRUModel(D_m, D_s, D_g, D_p, D_r, D_i, D_e, D_h, D_a,
                                n_classes=6,
                                listener_state=True,
                                context_attention="general2",
                                dropout_rec=0.1,
                                dropout=0.25,
                                emo_gru=True,
                                mode1=2,
                                norm=3,
                                residual=False)
model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"))
)

2023-07-09 17:50:45 | INFO | fairseq.file_utils | loading archive file ./checkpoints/iemocap
2023-07-09 17:50:45 | INFO | fairseq.file_utils | loading archive file iemocap-bin
2023-07-09 17:50:47 | INFO | fairseq.tasks.sentence_prediction | [input] dictionary: 50265 types
2023-07-09 17:50:47 | INFO | fairseq.tasks.sentence_prediction | [label] dictionary: 17 types
2023-07-09 17:50:55 | INFO | fairseq.models.roberta.model | {'_name': None, 'common': {'_name': None, 'no_progress_bar': False, 'log_interval': 100, 'log_format': 'simple', 'log_file': None, 'aim_repo': None, 'aim_run_hash': None, 'tensorboard_logdir': None, 'wandb_project': None, 'azureml_logging': False, 'seed': 1, 'cpu': False, 'tpu': False, 'bf16': False, 'memory_efficient_bf16': False, 'fp16': False, 'memory_efficient_fp16': False, 'fp16_no_flatten_grads': False, 'fp16_init_scale': 128, 'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0, 'on_cpu_convert_precision': False, 'min_loss_scale': 0.0001, 'threshold_loss_sca

Loading data from: comet/data/atomic/processed/generation/categories_oEffect#oReact#oWant#xAttr#xEffect#xIntent#xNeed#xReact#xWant-maxe1_17-maxe2_35-maxr_1.pickle
52


<All keys matched successfully>

# Feature Extraction

In [4]:
# Preprocessing Sentence
speakers, sentences = preprocess_text('./sentences/example_sentence_2.txt') # You can change this path and test different sentences

In [5]:
# Extracting Features

# RoBerta
batch = collate_tokens([roberta.encode(s) for s in sentences], pad_idx=1)
feat = roberta.extract_features(batch, return_all_hiddens=True)
r1 = torch.unsqueeze(torch.FloatTensor(np.array([row for row in feat[-1][:, 0, :].detach().numpy()])), dim=0)
r2 = torch.unsqueeze(torch.FloatTensor(np.array([row for row in feat[-2][:, 0, :].detach().numpy()])), dim=0)
r3 = torch.unsqueeze(torch.FloatTensor(np.array([row for row in feat[-3][:, 0, :].detach().numpy()])), dim=0)
r4 = torch.unsqueeze(torch.FloatTensor(np.array([row for row in feat[-4][:, 0, :].detach().numpy()])), dim=0)

# COMET
comet_features = comet_extractor.extract(sentences)
x1, x2, x3, x4, x5, x6, o1, o2, o3 = [torch.unsqueeze(torch.FloatTensor(data), dim=0) for data in comet_features]
# Masks
qmask = torch.unsqueeze(torch.FloatTensor([[1,0] if x=='M' else [0,1] for x in speakers]), dim=0)
umask = torch.unsqueeze(torch.FloatTensor([1]*len(speakers)), dim=1)

  x1, x2, x3, x4, x5, x6, o1, o2, o3 = [torch.unsqueeze(torch.FloatTensor(data), dim=0) for data in comet_features]


# Getting Predictions

In [6]:
preds = []
model.eval()
log_prob, _, alpha, alpha_f, alpha_b, _ = model(r1, r2, r3, r4, x5, x6, x1, o2, o3, qmask, umask, att2=True)
lp_ = log_prob.transpose(0,1).contiguous().view(-1, log_prob.size()[2]) # batch*seq_len, n_classes
pred_ = torch.argmax(lp_,1) # batch*seq_len
preds.append(pred_.data.cpu().numpy())
preds  = np.concatenate(preds)
log_probs = torch.squeeze(log_prob, dim=0)

# Plotting Preds

In [7]:
plot_emotions_1(log_probs, speakers, 'plots/plot1.png')
plot_emotions_2(log_probs, speakers, 'plots/plot2.png')
plot_sentences(sentences, preds, speakers, './plots/sentence.png')