In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class Attention(nn.Module):
    """Bidirectional-GRU encoder followed by additive attention pooling.

    Encodes a sequence with a bidirectional GRU, scores every timestep of the
    (tanh-projected) encoder states against a learned context vector, and
    returns the attention-weighted sum of the encoder states.

    Args:
        input_size: feature size of each timestep of the input.
        hidden_size: GRU hidden size per direction. Because the GRU is
            bidirectional, pooled outputs have ``2 * hidden_size`` features.
    """

    def __init__(self, input_size, hidden_size):
        # nn.Module must be initialised before submodules/parameters are assigned.
        super().__init__()
        self.Encoder = nn.GRU(input_size, hidden_size, bidirectional=True)
        self.MLP = nn.Sequential(nn.Linear(2 * hidden_size, 2 * hidden_size), nn.Tanh())
        self.context_vector = nn.Parameter(torch.randn(2 * hidden_size))
        # The GRU is not batch_first, so timesteps live on dim 0: normalise the
        # attention scores over that axis (an implicit dim would pick the batch
        # axis for 2-D scores).
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        """Pool a sequence into a single vector per batch element.

        Args:
            x: ``(seq_len, input_size)`` or ``(seq_len, batch, input_size)``,
               i.e. the non-batch-first layout ``nn.GRU`` uses by default.

        Returns:
            ``(2 * hidden_size,)`` for unbatched input, or
            ``(batch, 2 * hidden_size)`` for batched input.
        """
        h, _ = self.Encoder(x)                          # (seq, [batch,] 2H)
        u = self.MLP(h)                                 # (seq, [batch,] 2H)
        alpha = self.softmax(u @ self.context_vector)   # (seq, [batch])
        # Broadcast one weight per timestep over the feature axis, then sum
        # over the sequence axis only (not over every element).
        return torch.sum(alpha.unsqueeze(-1) * h, dim=0)
        

In [None]:
class HAN(nn.Module):
    """Hierarchical Attention Network document classifier.

    Words are projected by ``We``. Each sentence is pooled by word-level
    attention, the resulting sentence vectors are pooled by sentence-level
    attention, and a linear classifier maps the document vector to class
    log-probabilities.

    Args:
        word_embed_size: size of the raw per-word input vectors multiplied by ``We``.
        input_size: size of the projected word vectors fed to word attention.
        word_hidden_size: per-direction GRU hidden size of the word encoder.
        sentence_hidden_size: per-direction GRU hidden size of the sentence encoder.
        num_classes: number of output classes.
    """

    def __init__(self, word_embed_size, input_size, word_hidden_size, sentence_hidden_size, num_classes):
        # nn.Module must be initialised before submodules/parameters are assigned.
        super().__init__()
        self.We = nn.Parameter(torch.randn([word_embed_size, input_size]))
        self.WordAttention = Attention(input_size, word_hidden_size)
        self.SentenceAttention = Attention(2 * word_hidden_size, sentence_hidden_size)
        # LogSoftmax rather than Softmax: the training loop pairs this model with
        # nn.NLLLoss, which expects log-probabilities. The argmax is unchanged.
        self.classifier = nn.Sequential(
            nn.Linear(2 * sentence_hidden_size, num_classes),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, article):
        """Classify one article.

        Args:
            article: tensor whose last dim is ``word_embed_size``. Each slice along
                dim 0 is treated as one sentence. Presumably the shape is
                ``(num_sentences, num_words, word_embed_size)`` (TODO confirm
                against the dataset).

        Returns:
            Log-probabilities over ``num_classes`` along the last dim.
        """
        embedded = article @ self.We
        sentence_vectors = torch.stack([self.WordAttention(sentence) for sentence in embedded])
        document_vector = self.SentenceAttention(sentence_vectors)
        return self.classifier(document_vector)

In [None]:
# Only the training split is shuffled, so each epoch sees a different batch
# order. Validation/test order is irrelevant to their metrics and stays fixed.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE,
                         shuffle=True, pin_memory=True)
valloader = DataLoader(valset, batch_size=BATCH_SIZE,
                       shuffle=False, pin_memory=True)
testloader = DataLoader(testset, batch_size=BATCH_SIZE,
                        shuffle=False, pin_memory=True)

In [None]:
def test(data_loader, model, loss_function):
    model.eval()
    Loss = 0
    correct = 0
    for input, target in enumerate(data_loader):
        output = model(input)
        Loss += loss_function(output, label)
        correct += (output.argmax(0) == y).type(torch.float).sum().item()
        
    Loss /= len(data_loader.dataset)
    accuracy = correct / len(data_loader.dataset)
    
    return Loss, accuracy


def train(data_loader, model, optimizer, scheduler, loss_function):
    """Run one training epoch and return the per-batch losses.

    Args:
        data_loader: iterable of ``(inputs, labels)`` batches.
        model: module to train; batches are moved to its parameters' device.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per batch, or ``None`` to skip.
        loss_function: criterion called as ``loss_function(output, labels)``.

    Returns:
        List of per-batch loss values as Python floats.
    """
    model.train()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    losses = []
    for inputs, labels in data_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Clear stale gradients BEFORE backward; clearing them after backward
        # (as before) wiped the fresh gradients so step() never updated anything.
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        losses.append(loss.item())
        print("loss {}".format(losses[-1]))

    return losses

In [None]:
# Seed before building the model so weight init (and loader shuffling) is reproducible.
SEED = 0
torch.manual_seed(SEED)

model = HAN(word_embed_size, input_size, word_hidden_size, sentence_hidden_size, num_classes).to(DEVICE)

lr = 1e-3
epochs = 3
# The scheduler is stepped once per batch, so anneal over all batches of all epochs.
iteration = epochs * len(trainloader)
optimizer = optim.Adam(model.parameters(), lr=lr)
# NLLLoss expects the model to emit log-probabilities.
loss_function = torch.nn.NLLLoss()
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iteration)

iter_loss = []
epoch_loss = []
best_acc = 0
for t in range(epochs):
    print(f'Epoch {t} starts.')
    tloss = train(trainloader, model, optimizer, lr_scheduler, loss_function)
    # Evaluate with the same criterion used for training.
    val_loss, val_acc = test(valloader, model, loss_function)

    iter_loss.extend(tloss)
    epoch_loss.append(sum(tloss) / len(tloss) if tloss else float('nan'))

    print(f'Epoch {t}: LOSS = {epoch_loss[-1]}, VAL-ACC = {val_acc}')

    # Keep the checkpoint with the best validation accuracy seen so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'HAN_best.pth')

torch.save(model.state_dict(), 'HAN_last.pth')

print(model)