In [161]:
import torch
import numpy as np
import torch.nn as nn
import math
import os
import pandas as pd
import torch.optim as optim
import torch.nn.functional as F
import json
from torch import nn, Tensor

MAX_LENGTH = 520
tokens = json.load(open("tokens.txt"))
model = torch.load("embedding_model")
file_path = os.path.join(os.path.expanduser("~"), "Downloads", "mental_health.csv")
orig_dataset = pd.read_csv(file_path)
orig_dataset = orig_dataset.to_numpy()
trainingSet =  orig_dataset[np.random.choice(orig_dataset.shape[0], 4, replace=True)] #extract training set
contextSet = [trainingSet[i][0] for i in range(len(trainingSet))] 
responseSet = [trainingSet[i][1] for i in range(len(trainingSet))]
#tokenizing the context and response set, also 0 is special token for unknown word
contextSet_tokenized = [[tokens[word] if word in tokens else 0 for word in example.split()] 
                         for example in contextSet ]
responseSet_tokenized = [[tokens[word] if word in tokens else 0 for word in example.split()] 
                         for example in responseSet ]

#convert token to input embedding for context and response set filled with padding
#tokens if end of sentence. 1 is special token for padding 
contextSet_embedding = []        
for context in contextSet_tokenized:
    contextEmbedding = []
    for i in range(MAX_LENGTH):
        if i>= len(context):
            contextEmbedding.append(model["embeddings.weight"][1])
            continue
        contextEmbedding.append(model["embeddings.weight"][context[i]])
    contextEmbedding = torch.stack(contextEmbedding)
    contextSet_embedding.append(contextEmbedding[:])

contextSet_embedding = torch.stack(contextSet_embedding)

responseSet_embedding = []        
for response in responseSet_tokenized:
    responseEmbedding = []
    for i in range(MAX_LENGTH):
        if i>= len(response):
            responseEmbedding.append(model["embeddings.weight"][1])
            continue
        responseEmbedding.append(model["embeddings.weight"][response[i]])
    responseEmbedding = torch.stack(responseEmbedding)
    responseSet_embedding.append(responseEmbedding[:])
    
responseSet_embedding = torch.stack(responseSet_embedding)
print(responseSet_embedding.shape)


torch.Size([4, 520, 252])


In [162]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [163]:
pe = PositionalEncoding(252, max_len = MAX_LENGTH)
pe.forward(contextSet_embedding)

tensor([[[-0.0000,  0.6529, -0.4049,  ..., -0.3518,  0.7513,  2.7533],
         [-0.9623, -0.9468,  0.0000,  ...,  0.1233, -0.4650,  1.2283],
         [ 0.6168,  2.7969, -1.3155,  ...,  1.8697, -0.9027, -0.1001],
         ...,
         [-0.4917,  0.4138,  1.2308,  ...,  0.4614, -1.6518,  0.2978],
         [-0.4917,  0.4138,  1.2308,  ...,  0.4614, -1.6518,  0.2978],
         [-0.4917,  0.4138,  1.2308,  ...,  0.4614, -1.6518,  0.2978]],

        [[-0.0273, -1.4576,  0.9074,  ...,  0.1233, -0.4649,  1.2283],
         [ 1.3544,  1.4590, -0.0797,  ...,  1.8897,  0.7357, -0.4069],
         [ 0.9604,  1.1203, -0.0798,  ...,  1.3011,  0.2714,  1.9126],
         ...,
         [ 0.4433, -0.0969,  2.1212,  ...,  0.4614, -1.6516,  0.2978],
         [ 0.4433, -0.0000,  2.1212,  ...,  0.4614, -1.6516,  0.2978],
         [ 0.4433, -0.0969,  2.1212,  ...,  0.4614, -1.6516,  0.2978]],

        [[ 0.4201, -0.0000, -0.0000,  ...,  0.0664,  1.2768,  0.2659],
         [ 0.0000,  0.6150, -0.0631,  ...,  0