Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
soujanyaporia committed Oct 28, 2019
1 parent 230695b commit 810e931
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions DialogueRNN/train_MELD.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ def train_or_eval_model(model, loss_function, dataloader, epoch, optimizer=None,
# import ipdb;ipdb.set_trace()
textf, acouf, qmask, umask, label =\
[d.cuda() for d in data[:-1]] if cuda else data[:-1]
# log_prob, alpha, alpha_f, alpha_b = model(torch.cat((textf,acouf),dim=-1), qmask,umask) # seq_len, batch, n_classes
log_prob, alpha, alpha_f, alpha_b = model(textf, qmask,umask) # seq_len, batch, n_classes
if feature_type == "audio":
log_prob, alpha, alpha_f, alpha_b = model(acouf, qmask,umask) # seq_len, batch, n_classes
elif feature_type == "text":
log_prob, alpha, alpha_f, alpha_b = model(textf, qmask,umask) # seq_len, batch, n_classes
else:
log_prob, alpha, alpha_f, alpha_b = model(torch.cat((textf,acouf),dim=-1), qmask,umask) # seq_len, batch, n_classes
lp_ = log_prob.transpose(0,1).contiguous().view(-1,log_prob.size()[2]) # batch*seq_len, n_classes
labels_ = label.view(-1) # batch*seq_len
loss = loss_function(lp_, labels_, umask)
Expand Down Expand Up @@ -119,7 +123,8 @@ def train_or_eval_model(model, loss_function, dataloader, epoch, optimizer=None,
writer = SummaryWriter()

# choose between 'sentiment' or 'emotion'
classification_type = 'sentiment'
classification_type = 'emotion'
feature_type = 'multimodal'

data_path = 'DialogueRNN_features/MELD_features/'
batch_size = 30
Expand All @@ -133,7 +138,15 @@ def train_or_eval_model(model, loss_function, dataloader, epoch, optimizer=None,
l2 = 0.00001
lr = 0.0005

D_m = 600
if feature_type == 'text':
print("Running on the text features........")
D_m = 600
elif feature_type == 'audio':
print("Running on the audio features........")
D_m = 300
else:
print("Running on the multimodal features........")
D_m = 900
D_g = 150
D_p = 150
D_e = 100
Expand Down

0 comments on commit 810e931

Please sign in to comment.