In [29]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

In [371]:
def positional_encoder(d, N, R = 1000):
    """
    Sinusoidal positional encoding, returned as an (N, d) float32 tensor.

    For position t in [0, N) and frequency index k in [0, d/2):
        pe[t, 2k]     = sin(t / R**(2k/d))
        pe[t, 2k + 1] = cos(t / R**(2k/d))

    See:
        https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
        https://en.wikipedia.org/wiki/Transformer_(machine_learning_model).
    R should be much larger than d. Wikipedia calls R -> N.

    Raises:
        ValueError: if d is odd (sin/cos columns come in pairs).
    """
    if d % 2 != 0:
        raise ValueError(f"positional_encoder needs an even embedding size, got d={d}")
    positions = torch.arange(N, dtype=torch.float32)                # (N,)
    exponents = torch.arange(0, d, 2, dtype=torch.float32) / d      # 2k/d, (d/2,)
    # Computed in torch (not numpy) so the result is always float32 and
    # adds cleanly to float32 model inputs.
    thetas = positions[:, None] / torch.pow(float(R), exponents)[None, :]  # (N, d/2)
    # Stack on a trailing axis, then flatten: column 2k is sin, 2k+1 is cos.
    pairs = torch.stack((torch.sin(thetas), torch.cos(thetas)), dim=-1)    # (N, d/2, 2)
    return pairs.reshape(N, d)

In [372]:
def create_mask(b, N):
    """
    Return an (N, N) boolean causal (lower-triangular) mask.

    mask[i, j] is True when j <= i, so query position i may attend to key
    positions 0..i. `b` (batch size) is accepted for call compatibility but is
    not used: the (N, N) mask broadcasts against (b, N, N) attention scores.
    """
    return torch.tril(torch.ones(N, N, dtype=torch.bool))
 

In [400]:
# (b = batch size, N = sequence length, M = embedding size)

class AttnHead(nn.Module):
    """
    Single self-attention head followed by a mean over the feature axis.

    Adds the sinusoidal encoding from `positional_encoder` to the input,
    projects it to queries/keys/values with bias-free M x M linear maps,
    applies scaled dot-product attention and averages each output position's
    M features, giving one scalar per sequence position.

    Args:
        N: sequence length the head expects (x.shape[-2] must equal N).
        M: embedding size (x.shape[-1] must equal M).
    """

    def __init__(self, N, M):
        super().__init__()
        self.WQ = nn.Linear(in_features = M, out_features = M, bias = False)
        self.WK = nn.Linear(in_features = M, out_features = M, bias = False)
        self.WV = nn.Linear(in_features = M, out_features = M, bias = False)
        self.N = N
        self.M = M

    def create_mask(self):
        """
        Return a (1, N, N) boolean causal mask.

        mask[0, i, j] is True when j <= i, i.e. query position i may attend to
        key positions 0..i. The leading singleton dim broadcasts over the batch.
        """
        return torch.tril(torch.ones(self.N, self.N, dtype=torch.bool)).unsqueeze(0)

    def dot_product_attn(self, Q, K, V, mask=None):
        """
        Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

        Q, K, V are (..., N, M): batched (b, N, M) or unbatched (N, M).
        mask, if given, is broadcastable to (..., N, N); nonzero/True entries
        mark key positions a query may attend to, the rest get zero weight.
        A query row that is fully masked produces NaNs.
        Returns a tensor shaped like V.
        """
        # Key feature size M; shape[1] would be N for batched input.
        d = K.shape[-1]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / d ** 0.5   # (..., N, N)
        if mask is not None:
            # Mask before the softmax so each row still sums to 1.
            scores = scores.masked_fill(torch.logical_not(mask), float("-inf"))
        # Normalise over keys (last axis), not over queries.
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, V)

    def forward(self, x):
        """
        Map x of shape (b, N, M) to (b, N): attention output averaged over M.

        No mask is applied, so every position attends to the whole sequence.
        """
        # Match the input's device/dtype so the addition never mixes CPU/GPU
        # or float32/float64 tensors.
        pe = positional_encoder(self.M, self.N).to(device=x.device, dtype=x.dtype)
        x = x + pe
        Q, K, V = self.WQ(x), self.WK(x), self.WV(x)
        return self.dot_product_attn(Q, K, V).mean(-1)

In [406]:
# B = number of samples

class random_reverse(Dataset):
    """
    Synthetic "reverse the sequence" regression dataset.

    Each sample is an (N, M) tensor whose N positions each hold one
    standard-normal scalar, copied across all M feature columns. The target is
    the length-N vector of those scalars in reverse order.

    Args:
        B: number of samples.
        N: sequence length.
        M: embedding (feature) size.
    """

    def __init__(self, B, N, M):
        scalars = torch.normal(mean = torch.zeros(B, N, 1))  # (B, N, 1), std 1
        self.data = scalars.repeat(1, 1, M)                  # (B, N, M)
        # All M columns are identical, so averaging them recovers the scalar.
        self.labels = torch.flip(self.data.mean(dim=2), dims=[1])  # (B, N), reversed

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

In [None]:
# Train the attention head to output the reversed input sequence.
torch.manual_seed(0)  # reproducible data, weight init and shuffling

rr = random_reverse(100000, 3, 10)
batch_size = 512
data_loader = DataLoader(rr, batch_size=batch_size, shuffle=True)

head = AttnHead(3, 10)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(head.parameters(), lr=1e-4)

num_epochs = 5      # the loop below uses this (it used to hard-code range(1000))
print_every = 100   # number of batches between loss reports

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(data_loader):
        optimizer.zero_grad()

        outputs = head(inputs)          # (batch, N) predicted reversed sequence
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % print_every == 0:
            # Mean loss over the last `print_every` batches.
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / print_every))
            running_loss = 0.0