In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nltk
import numpy as np

from nltk.corpus import brown
from nltk.corpus import treebank
from nltk.corpus import conll2000

%load_ext autoreload
%autoreload 2

In [2]:
nltk.download('brown')
nltk.download('treebank')
nltk.download('conll2000')
nltk.download('universal_tagset')

[nltk_data] Downloading package brown to
[nltk_data]     /home/grad3/peilun/nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package treebank to
[nltk_data]     /home/grad3/peilun/nltk_data...
[nltk_data]   Package treebank is already up-to-date!
[nltk_data] Downloading package conll2000 to
[nltk_data]     /home/grad3/peilun/nltk_data...
[nltk_data]   Package conll2000 is already up-to-date!
[nltk_data] Downloading package universal_tagset to
[nltk_data]     /home/grad3/peilun/nltk_data...
[nltk_data]   Package universal_tagset is already up-to-date!


True

In [3]:
print("brown words number: ", len(brown.words()))
print("treebank words number: ", len(treebank.words()))
print("conll2000 words number: ", len(conll2000.words()))


brown words number:  1161192
treebank words number:  100676
conll2000 words number:  259104


In [4]:
treebank_corpus = treebank.tagged_sents(tagset='universal')
brown_corpus = brown.tagged_sents(tagset='universal')
conll_corpus = conll2000.tagged_sents(tagset='universal')
tagged_sentences = treebank_corpus + brown_corpus + conll_corpus

In [5]:
print(len(brown.words())/len(brown_corpus))
print(len(treebank.words())/len(treebank_corpus))
print(len(conll2000.words())/len(conll_corpus))

20.250994070456922
25.722023505365357
23.66678845451224


In [None]:
type(brown.words())

In [None]:
tagged_sentences[7]

In [None]:
X = [] # store input sequence
Y = [] # store output sequence

for sentence in tagged_sentences:
    
    X_sentence = []
    Y_sentence = []
    for entity in sentence: 
        X_sentence.append(entity[0]) # entity[0] contains the word
        Y_sentence.append(entity[1]) # entity[1] contains corresponding tag
    X.append(X_sentence)
    Y.append(Y_sentence)
num_words = len(set([word.lower() for sentence in X for word in sentence]))
num_tags   = len(set([word.lower() for sentence in Y for word in sentence]))
print("Total number of tagged sentences: {}".format(len(X)))
print("Vocabulary size: {}".format(num_words))
print("Total number of tags: {}".format(num_tags))

In [None]:
# let’s look at first data point
# this is one data point that will be fed to the RNN

print('sample X:', X[0],'\n')
print('sample Y:', Y[0],'\n')

In [None]:
print("Length of first input sequence : {}".format(len(X[0])))
print("Length of first output sequence : {}".format(len(Y[0])))

In [None]:
torch.save(X, 'X.pt')

In [None]:
Z = torch.load('X.pt')

In [6]:

# https://github.com/bentrevett/pytorch-pos-tagging/blob/master/1%20-%20BiLSTM%20for%20PoS%20Tagging.ipynb

import torch
import torch.nn as nn
import torch.optim as optim

import torchtext

import spacy
import numpy as np

import time
import random

In [7]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [8]:
TEXT = torchtext.data.Field(lower = True)
UD_TAGS = torchtext.data.Field(unk_token = None)
PTB_TAGS = torchtext.data.Field(unk_token = None)



In [9]:
fields = (("text", TEXT), ("udtags", UD_TAGS), ("ptbtags", PTB_TAGS))

In [11]:
train_data, valid_data, test_data = torchtext.datasets.UDPOS.splits(fields)



In [12]:
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")
print(f"Number of testing examples: {len(test_data)}")

Number of training examples: 12543
Number of validation examples: 2002
Number of testing examples: 2077


In [13]:
print(vars(train_data.examples[0])['text'])

['al', '-', 'zaman', ':', 'american', 'forces', 'killed', 'shaikh', 'abdullah', 'al', '-', 'ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'qaim', ',', 'near', 'the', 'syrian', 'border', '.']


In [14]:
print(vars(train_data.examples[0])['udtags'])

['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT']


In [15]:
print(vars(train_data.examples[0])['ptbtags'])

['NNP', 'HYPH', 'NNP', ':', 'JJ', 'NNS', 'VBD', 'NNP', 'NNP', 'NNP', 'HYPH', 'NNP', ',', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'NNP', ',', 'IN', 'DT', 'JJ', 'NN', '.']


In [16]:
MIN_FREQ = 2
TEXT.build_vocab(train_data, 
                 min_freq = MIN_FREQ,
                 vectors = "glove.6B.100d",
                 unk_init = torch.Tensor.normal_)
UD_TAGS.build_vocab(train_data)
PTB_TAGS.build_vocab(train_data)

In [17]:
print(f"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}")
print(f"Unique tokens in UD_TAG vocabulary: {len(UD_TAGS.vocab)}")
print(f"Unique tokens in PTB_TAG vocabulary: {len(PTB_TAGS.vocab)}")

Unique tokens in TEXT vocabulary: 8866
Unique tokens in UD_TAG vocabulary: 18
Unique tokens in PTB_TAG vocabulary: 51


In [18]:
print(TEXT.vocab.freqs.most_common(20))

[('the', 9076), ('.', 8640), (',', 7021), ('to', 5137), ('and', 5002), ('a', 3782), ('of', 3622), ('i', 3379), ('in', 3112), ('is', 2239), ('you', 2156), ('that', 2036), ('it', 1850), ('for', 1842), ('-', 1426), ('have', 1359), ('"', 1296), ('on', 1273), ('was', 1244), ('with', 1216)]


In [19]:
print(UD_TAGS.vocab.itos)

['<pad>', 'NOUN', 'PUNCT', 'VERB', 'PRON', 'ADP', 'DET', 'PROPN', 'ADJ', 'AUX', 'ADV', 'CCONJ', 'PART', 'NUM', 'SCONJ', 'X', 'INTJ', 'SYM']


In [20]:
print(UD_TAGS.vocab.freqs.most_common())

[('NOUN', 34781), ('PUNCT', 23679), ('VERB', 23081), ('PRON', 18577), ('ADP', 17638), ('DET', 16285), ('PROPN', 12946), ('ADJ', 12477), ('AUX', 12343), ('ADV', 10548), ('CCONJ', 6707), ('PART', 5567), ('NUM', 3999), ('SCONJ', 3843), ('X', 847), ('INTJ', 688), ('SYM', 599)]


In [21]:
print(PTB_TAGS.vocab.freqs.most_common())

[('NN', 26915), ('IN', 20724), ('DT', 16817), ('NNP', 12449), ('PRP', 12193), ('JJ', 11591), ('RB', 10831), ('.', 10317), ('VB', 9476), ('NNS', 8438), (',', 8062), ('CC', 6706), ('VBD', 5402), ('VBP', 5374), ('VBZ', 4578), ('CD', 3998), ('VBN', 3967), ('VBG', 3330), ('MD', 3294), ('TO', 3286), ('PRP$', 3068), ('-RRB-', 1008), ('-LRB-', 973), ('WDT', 948), ('WRB', 869), (':', 866), ('``', 813), ("''", 785), ('WP', 760), ('RP', 755), ('UH', 689), ('POS', 684), ('HYPH', 664), ('JJR', 503), ('NNPS', 498), ('JJS', 383), ('EX', 359), ('NFP', 338), ('GW', 294), ('ADD', 292), ('RBR', 276), ('$', 258), ('PDT', 175), ('RBS', 169), ('SYM', 156), ('LS', 117), ('FW', 93), ('AFX', 48), ('WP$', 15), ('XX', 1)]


In [22]:
def tag_percentage(tag_counts):
    
    total_count = sum([count for tag, count in tag_counts])
    
    tag_counts_percentages = [(tag, count, count/total_count) for tag, count in tag_counts]
        
    return tag_counts_percentages

In [23]:
print("Tag\t\tCount\t\tPercentage\n")

for tag, count, percent in tag_percentage(UD_TAGS.vocab.freqs.most_common()):
    print(f"{tag}\t\t{count}\t\t{percent*100:4.1f}%")


Tag		Count		Percentage

NOUN		34781		17.0%
PUNCT		23679		11.6%
VERB		23081		11.3%
PRON		18577		 9.1%
ADP		17638		 8.6%
DET		16285		 8.0%
PROPN		12946		 6.3%
ADJ		12477		 6.1%
AUX		12343		 6.0%
ADV		10548		 5.2%
CCONJ		6707		 3.3%
PART		5567		 2.7%
NUM		3999		 2.0%
SCONJ		3843		 1.9%
X		847		 0.4%
INTJ		688		 0.3%
SYM		599		 0.3%


In [24]:
print("Tag\t\tCount\t\tPercentage\n")

for tag, count, percent in tag_percentage(PTB_TAGS.vocab.freqs.most_common()):
    print(f"{tag}\t\t{count}\t\t{percent*100:4.1f}%")

Tag		Count		Percentage

NN		26915		13.2%
IN		20724		10.1%
DT		16817		 8.2%
NNP		12449		 6.1%
PRP		12193		 6.0%
JJ		11591		 5.7%
RB		10831		 5.3%
.		10317		 5.0%
VB		9476		 4.6%
NNS		8438		 4.1%
,		8062		 3.9%
CC		6706		 3.3%
VBD		5402		 2.6%
VBP		5374		 2.6%
VBZ		4578		 2.2%
CD		3998		 2.0%
VBN		3967		 1.9%
VBG		3330		 1.6%
MD		3294		 1.6%
TO		3286		 1.6%
PRP$		3068		 1.5%
-RRB-		1008		 0.5%
-LRB-		973		 0.5%
WDT		948		 0.5%
WRB		869		 0.4%
:		866		 0.4%
``		813		 0.4%
''		785		 0.4%
WP		760		 0.4%
RP		755		 0.4%
UH		689		 0.3%
POS		684		 0.3%
HYPH		664		 0.3%
JJR		503		 0.2%
NNPS		498		 0.2%
JJS		383		 0.2%
EX		359		 0.2%
NFP		338		 0.2%
GW		294		 0.1%
ADD		292		 0.1%
RBR		276		 0.1%
$		258		 0.1%
PDT		175		 0.1%
RBS		169		 0.1%
SYM		156		 0.1%
LS		117		 0.1%
FW		93		 0.0%
AFX		48		 0.0%
WP$		15		 0.0%
XX		1		 0.0%


In [123]:
BATCH_SIZE = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = torchtext.data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    device = device)

In [124]:
len(train_iterator)

126

In [125]:
a=next(iter(train_iterator))
print(a.udtags.dtype)
# print(a.text.shape)

torch.int64


In [126]:
class VanillaRNN(nn.Module):
    # Assuming only 
    
    def __init__(self, 
                 input_dim, # vocab size
                 embedding_dim, 
                 hidden_dim, 
                 output_dim, 
                 pad_idx):
    
        
        super().__init__()
        
        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx = pad_idx)
        self.embedding.requires_grad = False # do not change pretrainined embedding
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.Wb_xh = nn.Linear(embedding_dim, hidden_dim, bias=True)
        self.Wb_hh = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.Wb_hy = nn.Linear(hidden_dim, output_dim, bias=True)
        
        
    def forward(self, text):

        #text = [sent len, batch size]
        
        #pass text through embedding layer
        embedded = self.embedding(text)
        sent_len, batch_sz, embed_dim = embedded.shape 
        
        #embedded = [sent len, batch size, emb dim]
        
        #pass embeddings into LSTM
#         outputs, (hidden, cell) = self.lstm(embedded)
        # mannuly do RNN
        h = torch.zeros(1, batch_sz, self.hidden_dim, device=text.device)
        Ys_logits = []
        Hs = []
        for i in range(sent_len):
            h = torch.tanh(self.Wb_xh(embedded[i,:,:]) + self.Wb_hh(h))
            Hs.append(h)
            y = self.Wb_hy(h)
            Ys_logits.append(y)
        
        return torch.cat(Ys_logits, dim=0) # sent_len * batch_sz * output_dim
    
    

            
        

INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
OUTPUT_DIM = len(UD_TAGS.vocab)
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token] # not used for our purpose


model = VanillaRNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, PAD_IDX).to(device)

# def count_parameters(model):
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# print(f'The model has {count_parameters(model):,} trainable parameters')


pretrained_embeddings = TEXT.vocab.vectors

# print(pretrained_embeddings.shape)

model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM).to(device)
# print(model.embedding.weight.data)
model.embedding.requires_grad = False

optimizer = optim.Adam(model.parameters())
TAG_PAD_IDX = UD_TAGS.vocab.stoi[UD_TAGS.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = TAG_PAD_IDX)

model = model.to(device)
criterion = criterion.to(device)

def categorical_accuracy(preds, y, tag_pad_idx):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    non_pad_elements = (y != tag_pad_idx).nonzero()
    correct = max_preds[non_pad_elements].squeeze(1).eq(y[non_pad_elements])
    return correct.sum() / torch.tensor([y[non_pad_elements].shape[0]], device=device)

def train(model, iterator, optimizer, criterion, tag_pad_idx):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        
        text = batch.text
        tags = batch.udtags
        
        optimizer.zero_grad()
        
        #text = [sent len, batch size]
        
        predictions = model(text)
        
        #predictions = [sent len, batch size, output dim]
        #tags = [sent len, batch size]
        
        predictions = predictions.view(-1, predictions.shape[-1])
        tags = tags.view(-1)
        
        #predictions = [sent len * batch size, output dim]
        #tags = [sent len * batch size]
        
        loss = criterion(predictions, tags)
                
        acc = categorical_accuracy(predictions, tags, tag_pad_idx)
        
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model, iterator, criterion, tag_pad_idx):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            text = batch.text
            tags = batch.udtags
            
            predictions = model(text)
            
            predictions = predictions.view(-1, predictions.shape[-1])
            tags = tags.view(-1)
            
            loss = criterion(predictions, tags)
            
            acc = categorical_accuracy(predictions, tags, tag_pad_idx)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [120]:
class seqTP_POS(nn.Module):
    
    def __init__(self, 
                 input_dim, # vocab size
                 embedding_dim, 
                 hidden_dim, 
                 output_dim, 
                 pad_idx):
    
        super().__init__()
        
        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx = pad_idx)
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.Wb_xh = nn.Linear(embedding_dim, hidden_dim, bias=True)
        self.Wb_hh = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.Wb_hy = nn.Linear(hidden_dim, output_dim, bias=True)
        
        self.Vc_hh = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.mse_loss = nn.MSELoss()
        
        
        
    def forward(self, text):

        #text = [sent len, batch size]
        
        embedded = self.embedding(text)
        #embedded = [sent len, batch size, emb dim]
        sent_len, batch_sz, embed_dim = embedded.shape 

        
        # mannuly do RNN
        h = torch.zeros(1, batch_sz, self.hidden_dim, device=text.device)
        Ys_logits = []
        Hs = []
        
        for i in range(sent_len):
            h = torch.tanh(self.Wb_xh(embedded[i,:,:]) + self.Wb_hh(h))
            Hs.append(h)
            y = self.Wb_hy(h)
            Ys_logits.append(y)
        
        return torch.cat(Ys_logits, dim=0) # [sent_len, batch_sz, output_dim]
    
    def comp_losses(self, text, tags, criterion, i_lr):
        
        embedded = self.embedding(text)
        sent_len, batch_sz, embed_dim = embedded.shape 
        

        h = torch.zeros(1, batch_sz, self.hidden_dim, device=text.device)
        Hs_activations = []
        Hs_targets_from_y = [] # h_tar = y - i_lr * delta local_loss / delta h
        
        for i in range(sent_len):
            
            # Forward activations
            h = torch.tanh(self.Wb_xh(embedded[i,:,:]) + self.Wb_hh(h.detach())) # breaking chain rule
            Hs_activations.append(h)
            
            # Local targets
            h_for_y = h.detach()
            h_for_y.requires_grad = True
            y = self.Wb_hy(h_for_y) # only depends on local Wb_hy and h, breaking chain rule

            y = y.view(-1, y.shape[-1])
            tag = tags[i,:].view(-1)
            
            word_loss = criterion(y, tag)
            word_loss.backward() # accumute grad for Wb_hy (and h_for_y, whic is only used here)
            
            # used for Wb_hy updats, do not zero out
            self.Wb_hy.weight.grad.zero_()
            self.Wb_hy.bias.grad.zero_()
            Hs_targets_from_y.append((h.detach() - i_lr * h_for_y.grad.detach()).detach())

        # Backward with inverse functions G
        counter = torch.zeros(batch_sz, device=text.device)
        
        Y_total_loss = torch.tensor(0., device=text.device)
        F_total_loss = torch.tensor(0., device=text.device)
        G_total_loss = torch.tensor(0., device=text.device)
        
        for i in range(sent_len-1, -1, -1):
            tag = tags[i,:].view(-1) # size: batch_sz, dtype = torch.int64
            counter += (tag != 0)
            
            if i == sent_len - 1:
                Hs_targets_from_y[i] = Hs_targets_from_y[i].detach()
                
                # y loss for last step 
                y_logits = self.Wb_hy(Hs_activations[i].detach()) # only depends on local Wb_hy
                y_logits = y_logits.view(-1, y_logits.shape[-1])
                tag = tags[i,:].view(-1)
                y_loss = criterion(y_logits, tag)
                Y_total_loss += y_loss
                
            else:
                beta = (1./counter).view(1, -1, 1) # [1, batch_sz, 1]
                alpha =(1. - beta).view(1, -1, 1)
                # apply DTP
                with torch.no_grad():
                    from_embedding = self.Wb_xh(embedded[i+1,:,:])
                    G_h_t1 = torch.tanh(self.Vc_hh(Hs_activations[i+1]) + from_embedding)
                    G_h_tar_h1 = torch.tanh(self.Vc_hh(Hs_targets_from_y[i+1]) + from_embedding)
                    targets_from_future = Hs_activations[i] - G_h_t1 + G_h_tar_h1 # DTP

                    aa = alpha * targets_from_future
                    bb = beta * Hs_targets_from_y[i]
                    Hs_targets_from_y[i] = (aa+bb).detach()
        
                # Compute loss
                y_logits = self.Wb_hy(Hs_activations[i].detach()) # only depends on local Wb_hy
                y_logits = y_logits.view(-1, y_logits.shape[-1])
                tag = tags[i,:].view(-1)
                y_loss = criterion(y_logits, tag)
                Y_total_loss += y_loss

                # F and x loss, due to forward prediction
                F_total_loss += self.mse_loss(Hs_activations[i+1], Hs_targets_from_y[i+1])
                
                # G loss
                rec = torch.tanh(self.Vc_hh(Hs_activations[i+1].detach()) +  from_embedding) # not updating Wb_xh
                G_total_loss += self.mse_loss(Hs_activations[i].detach(), rec) 
                
        return Y_total_loss, F_total_loss, G_total_loss
    
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
OUTPUT_DIM = len(UD_TAGS.vocab)
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token] # not used for our purpose

model = seqTP_POS(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, PAD_IDX).to(device)

pretrained_embeddings = TEXT.vocab.vectors

model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM).to(device)
model.embedding.requires_grad = False # do not true embedding

optimizer = optim.Adam(model.parameters())

TAG_PAD_IDX = UD_TAGS.vocab.stoi[UD_TAGS.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = TAG_PAD_IDX)

model = model.to(device)
criterion = criterion.to(device)


In [121]:
def categorical_accuracy(preds, y, tag_pad_idx):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    non_pad_elements = (y != tag_pad_idx).nonzero()
    correct = max_preds[non_pad_elements].squeeze(1).eq(y[non_pad_elements])
    return correct.sum() / torch.tensor([y[non_pad_elements].shape[0]], device=device)

def train(model, iterator, optimizer, criterion, tag_pad_idx):
    
    G_epoch_loss = 0
    F_epoch_loss = 0
    Y_epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    # train G
    for batch in iterator:
        
        text = batch.text #text = [sent len, batch size]
        
        tags = batch.udtags
        
        optimizer.zero_grad()        
    
        Y_total_loss, F_total_loss, G_total_loss = model.comp_losses(text, tags, criterion, 0.1)
        
        G_total_loss.backward()
        optimizer.step()
        G_epoch_loss += G_total_loss.item()
        
    
    # train F, Y
    for batch in iterator:
        
        text = batch.text #text = [sent len, batch size]
        tags = batch.udtags
        
        optimizer.zero_grad()
        
        
        
#         predictions = model(text)
        
#         #predictions = [sent len, batch size, output dim]
#         #tags = [sent len, batch size]
        
#         predictions = predictions.view(-1, predictions.shape[-1])
#         tags = tags.view(-1)
        
#         #predictions = [sent len * batch size, output dim]
#         #tags = [sent len * batch size]
        
#         loss = criterion(predictions, tags)
                
    
        Y_total_loss, F_total_loss, G_total_loss = model.comp_losses(text, tags, criterion, 0.1)
        
        F_total_loss.backward(retain_graph=True)
        Y_total_loss.backward()
        
        optimizer.step()
        
        F_epoch_loss += F_total_loss.item()
        Y_epoch_loss += Y_total_loss.item()
        
    return G_epoch_loss / len(iterator), F_epoch_loss / len(iterator), Y_epoch_loss / len(iterator)

def evaluate(model, iterator, criterion, tag_pad_idx):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            text = batch.text
            tags = batch.udtags
            
            predictions = model(text)
            
            predictions = predictions.view(-1, predictions.shape[-1])
            tags = tags.view(-1)
            
            loss = criterion(predictions, tags)
            
            acc = categorical_accuracy(predictions, tags, tag_pad_idx)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [122]:
N_EPOCHS = 20

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
#     train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
    G_loss, F_loss, Y_loss = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut1-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain G Loss: {G_loss:.3f}, \tTrain F Loss: {F_loss:.3f}, \tTrain Y Loss: {Y_loss:.3f},')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /opt/conda/conda-bld/pytorch_1603091395256/work/torch/csrc/utils/python_arg_parser.cpp:945.)
  


Epoch: 01 | Epoch Time: 0m 39s
	Train G Loss: 5.391, 	Train F Loss: nan, 	Train Y Loss: nan,
	 Val. Loss: nan |  Val. Acc: 0.00%
Epoch: 02 | Epoch Time: 0m 38s
	Train G Loss: nan, 	Train F Loss: nan, 	Train Y Loss: nan,
	 Val. Loss: nan |  Val. Acc: 0.00%


KeyboardInterrupt: 

In [None]:
model.load_state_dict(torch.load('tut1-model.pt'))

test_loss, test_acc = evaluate(model, test_iterator, criterion, TAG_PAD_IDX)

print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}%')

In [114]:
a = torch.randn([1,2,3])
b = torch.tensor([0,1])

In [115]:
a

tensor([[[ 0.3412, -0.5604, -1.3266],
         [ 0.0555, -0.5062, -0.6640]]])

In [116]:
b

tensor([0, 1])

In [118]:
c = a * b.view(1, -1, 1)

In [119]:
c

tensor([[[ 0.0000, -0.0000, -0.0000],
         [ 0.0555, -0.5062, -0.6640]]])