In [43]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from transformers import BertLMHeadModel, AutoConfig
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

#### Building a PyTorch model

##### Self-Attention
$Attention \left(Q, K, V\right) = softmax \left( \frac{QK^T}{\sqrt{d_k}} \right) \cdot V$

<!-- ![alternative text](attention.png) -->
<div>
<center>
<img src="attention.png" height="400"/>
<img src="multihead.png" height="400"/>
</center>
</div>

In [44]:
def scaled_dot_product(q, k, v, dropout_p=0.1, training=True):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    Args:
        q: query tensor; its last dimension is the per-head size d_k.
        k: key tensor, matmul-compatible with ``q`` after swapping its last
            two dimensions.
        v: value tensor, matmul-compatible with the attention probabilities.
        dropout_p: dropout probability applied to the attention probabilities.
        training: whether dropout is active. Defaults to True, which keeps the
            original always-on behaviour; pass False for deterministic output.

    Returns:
        The attention-weighted combination of ``v``. The shape comments below
        follow the (bs, head, seq, hs // head) layout used by the caller.
    """
    # (bs, head, seq, hs // head)
    d_k = q.shape[-1]
    # d_k is a plain Python int, which torch.sqrt() rejects with a TypeError,
    # so the scaling factor is computed with ordinary float arithmetic.
    attn_score = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    # (bs, head, seq, seq)
    attn_probs = F.softmax(attn_score, dim=-1)
    attn_probs = F.dropout(attn_probs, p=dropout_p, training=training)
    # (bs, head, seq, hs // head)
    attn = torch.matmul(attn_probs, v)
    return attn

In [45]:
class SelfAttention(nn.Module):
    """Multi-head self-attention: Q/K/V projections followed by scaled dot-product attention.

    Args:
        hidden_size: size of the last dimension of the input and output.
        n_heads: number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``n_heads``.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        if hidden_size % n_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by n_heads ({n_heads})")
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.n_heads = n_heads

    def permute_for_scores(self, x):
        """Split the hidden dimension into heads and move the head axis forward."""
        # x: (batch_size, seq_len, hidden_size)
        x = x.view(*x.shape[:-1], self.n_heads, -1)
        # output: (bs, head, seq, hs // head)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention.

        Args:
            hidden_states: tensor of shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        # qkv layers
        q = self.permute_for_scores(self.q_proj(hidden_states))
        k = self.permute_for_scores(self.k_proj(hidden_states))
        v = self.permute_for_scores(self.v_proj(hidden_states))
        # core attention: (bs, head, seq, hs // head)
        output = scaled_dot_product(q, k, v)
        # permute() returns a new tensor rather than working in place, so the
        # result has to be assigned: (bs, seq, head, hs // head)
        output = output.permute(0, 2, 1, 3)
        # Merge the heads back into the hidden dimension. reshape() is used
        # because the permuted tensor is not contiguous, which view() rejects.
        output = output.reshape(output.shape[0], output.shape[1], -1)
        return output

##### Attention Layer
<div>
<center>
<img src="transformer.png" width="400"/>
</center>
</div>

In [46]:
class Projection(nn.Module):
    """Dense projection with dropout, a residual connection and layer normalisation."""

    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.dropout = nn.Dropout(p=0.1)
        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, add ``input_tensor`` as the residual, and normalise."""
        projected = self.dropout(self.dense(hidden_states))
        return self.layer_norm(projected + input_tensor)

class MLP(nn.Module):
    """Two linear layers of width ``hidden_size`` with a GELU in between."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)

    def forward(self, data):
        """Return ``linear2(gelu(linear1(data)))``."""
        return self.linear2(self.activation(self.linear1(data)))

class Attention(nn.Module):
    """Transformer-style layer: self-attention block followed by a feed-forward block.

    Each block ends in a ``Projection``, which applies a dense layer and
    dropout, adds the residual passed as its second argument, and applies
    layer normalisation.

    Args:
        hidden_size: size of the last dimension of the input and output.
        n_heads: number of attention heads passed to ``SelfAttention``.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.self_attn = SelfAttention(hidden_size, n_heads)
        self.proj1 = Projection(hidden_size)
        self.linear_net = MLP(hidden_size)
        self.proj2 = Projection(hidden_size)

    def forward(self, hidden_states):
        """Run the layer on ``hidden_states`` and return a tensor of the same shape."""
        # Attention block: the layer input is the residual.
        self_output = self.self_attn(hidden_states)
        attention_output = self.proj1(self_output, hidden_states)
        # Feed-forward block: the attention output is the residual. proj2
        # already applies dropout and the residual add, so they are not repeated
        # here. The previous code used an undefined self.dropout, added the
        # residual by hand, and then called proj2 without its required second
        # argument.
        linear_out = self.linear_net(attention_output)
        out = self.proj2(linear_out, attention_output)
        return out

In [47]:
# Instantiate the hand-written layer (hidden size 1024, 16 heads) and print its
# module tree so it can be compared with the BERT layer printed further down.
# Note: the name `model` is rebound to the CNN classifier later in the notebook.
model = Attention(hidden_size=1024, n_heads=16)
print(model)

Attention(
  (self_attn): SelfAttention(
    (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (proj1): Projection(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (linear_net): MLP(
    (linear1): Linear(in_features=1024, out_features=1024, bias=True)
    (activation): GELU(approximate=none)
    (linear2): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (proj2): Projection(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
)


In [48]:
# Build a BERT language-model-head model from the "bert-large-uncased"
# configuration, to serve as the reference architecture.
# NOTE(review): the model is constructed from a config object rather than via a
# model-level from_pretrained call, so presumably its weights are randomly
# initialised — confirm if pretrained weights are expected.
# NOTE(review): the cell output suggests setting `is_decoder=True` on the config
# when this class is used standalone.
config = AutoConfig.from_pretrained("bert-large-uncased")
bert_model = BertLMHeadModel(config)


If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [49]:
# Print the first encoder layer of the reference model for a side-by-side
# comparison with the hand-written Attention layer above.
print(bert_model.bert.encoder.layer[0])

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=1024, out_features=1024, bias=True)
      (key): Linear(in_features=1024, out_features=1024, bias=True)
      (value): Linear(in_features=1024, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=1024, out_features=1024, bias=True)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=1024, out_features=4096, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=4096, out_features=1024, bias=True)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)


In [50]:
def train(model, device="cuda", bs=8, seq_length=512,
          num_steps=100, lr=0.001, log_every=10):
    """Run a fixed number of AdamW steps on a constant synthetic batch.

    The batch is all ones (token ids, attention mask, token type ids) and the
    labels are a copy of the token ids.

    Args:
        model: module called as ``model(input_ids, attention_mask,
            token_type_ids, labels=labels)`` whose result exposes ``.loss``.
        device: device the model and the synthetic batch are placed on.
        bs: batch size of the synthetic batch.
        seq_length: sequence length of the synthetic batch.
        num_steps: number of optimisation steps to run.
        lr: AdamW learning rate.
        log_every: print the loss every this many steps.
    """
    input_ids = torch.ones(bs, seq_length, dtype=torch.long, device=device)
    attention_mask = torch.ones(bs, seq_length, dtype=torch.float16, device=device)
    token_type_ids = torch.ones(bs, seq_length, dtype=torch.long, device=device)
    labels = input_ids.clone()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # The batch never changes, so the input tuple is built once.
    inputs = (input_ids, attention_mask, token_type_ids)
    for step in range(num_steps):
        # Clear gradients from the previous step; without this they accumulate
        # and every update uses the sum of all earlier gradients.
        optimizer.zero_grad()
        loss = model(*inputs, labels=labels).loss
        loss.backward()
        optimizer.step()

        if step % log_every == 0:
            print(f"step {step} loss: {loss.item()}")

#### Train a simple model

##### Define model

In [51]:
class Model(nn.Module):
    """LeNet-style CNN: two conv/ReLU/average-pool stages and a three-layer classifier.

    The classifier expects 256 flattened features, i.e. 16 channels of 4x4,
    which is what the feature extractor produces for a 1x28x28 input such as
    the MNIST images loaded later in the notebook.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        conv_stages = [
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.feature_extractor = nn.Sequential(*conv_stages)

        # No softmax at the end: the network returns raw logits.
        head_layers = [
            nn.Linear(256, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        features = self.feature_extractor(x)
        flat = torch.flatten(features, 1)
        return self.classifier(flat)
    
# Instantiate the CNN with 10 output classes. This rebinds the notebook-level
# name `model`, which previously referred to the Attention layer.
model = Model(num_classes = 10)

#### Load dataset

In [52]:
def mnist(batch_size: int,
          data_dir: str = '/tmp/mnist_data'):
    """Build the MNIST train and test data loaders.

    Images are converted to tensors and normalised with mean 0.1307 and
    std 0.3081. The dataset is downloaded into ``data_dir`` if needed.

    Args:
        batch_size: batch size used by both loaders.
        data_dir: directory where the dataset is stored.

    Returns:
        ``(trainloader, testloader)``; only the training loader shuffles.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.1307], std=[0.3081]),
    ])

    def _make_loader(is_train):
        # One helper for both splits: the training split is the shuffled one.
        dataset = torchvision.datasets.MNIST(
            root=data_dir, train=is_train, transform=preprocess, download=True)
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train)

    train_split = _make_loader(True)
    test_split = _make_loader(False)
    return train_split, test_split

# Create the notebook-level loaders with a batch size of 512; the evaluation
# and training functions defined below read these names.
trainloader, testloader = mnist(batch_size = 512)

#### Pick an optimizer

In [53]:
def adam(model, init_lr, weight_decay=0.0):
    """Create an Adam optimizer over all parameters of ``model``.

    The previous body built ``torch.optim.SGD`` with momentum, duplicating
    ``sgd`` and contradicting this function's name.

    Args:
        model: module whose ``parameters()`` are optimised.
        init_lr: initial learning rate.
        weight_decay: weight-decay coefficient passed to Adam.

    Returns:
        A ``torch.optim.Adam`` instance.
    """
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=init_lr,
        weight_decay=weight_decay)
    return optimizer

def sgd(model, init_lr, weight_decay=0.0):
    """Create an SGD optimizer with momentum 0.9 over all parameters of ``model``.

    Args:
        model: module whose ``parameters()`` are optimised.
        init_lr: initial learning rate.
        weight_decay: weight-decay coefficient passed to SGD.

    Returns:
        A ``torch.optim.SGD`` instance.
    """
    hyperparams = dict(lr=init_lr, momentum=0.9, weight_decay=weight_decay)
    return torch.optim.SGD(model.parameters(), **hyperparams)

# Use SGD with momentum and an initial learning rate of 0.1 for the CNN.
optimizer = sgd(model, 0.1)

#### Pick a scheduler

In [54]:
def multistep(optimizer, milestones, gamma=0.1):
    """Create a MultiStepLR scheduler for ``optimizer``.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        milestones: scheduler step indices at which the rate is multiplied by ``gamma``.
        gamma: multiplicative decay factor.

    Returns:
        A ``torch.optim.lr_scheduler.MultiStepLR`` instance.
    """
    return torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones, gamma=gamma)

# Decay the learning rate by the default factor (0.1) at scheduler steps 5 and
# 10; the training loop below steps the scheduler once per epoch.
scheduler = multistep(optimizer, [5, 10])

In [55]:
# TensorBoard writer used by the training loop to log scalars.
tb_writer = SummaryWriter()

# Pick the compute device: CUDA if available, otherwise Apple MPS if available,
# otherwise the CPU.
device = ("cuda" if torch.cuda.is_available() 
          else ("mps" if torch.backends.mps.is_available() else "cpu"))

# Classification loss applied to the model outputs and integer labels.
criterion = nn.CrossEntropyLoss().to(device)

def test(model, testloader=testloader, device=device):
    """Compute the classification accuracy of ``model`` over a data loader.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        testloader: iterable of batches whose first two items are inputs and
            integer labels. Defaults to the notebook-level ``testloader``,
            bound when this function is defined.
        device: device each batch is moved to before the forward pass. The
            model is not moved here and is expected to be on it already.

    Returns:
        Accuracy as a percentage in [0, 100]; 0.0 if the loader yields no samples.
    """
    # Remember the current mode so evaluation does not silently leave the model
    # in eval mode for a caller that is in the middle of training.
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in testloader:
                images, labels = batch[0].to(device), batch[1].to(device)
                outputs = model(images)
                # The predicted class is the index of the largest score.
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        return 0.0
    return correct * 100 / total

def training_loop(model, optimizer, scheduler, num_epochs, device=device):
    """Train ``model`` on the notebook-level ``trainloader`` and evaluate after each epoch.

    Per step, the loss and the accuracy on the current batch are logged to
    TensorBoard and the loss is printed. Per epoch, the learning rate is
    logged, the scheduler is stepped, and the accuracy from ``test`` is
    printed and logged.

    Args:
        model: module to train; it is moved to ``device``.
        optimizer: optimizer already bound to the parameters of ``model``.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of passes over ``trainloader``.
        device: device used for the model and for every batch.
    """
    global_step = 0
    model = model.to(device)
    for _ in range(num_epochs):
        # switch the model to the training mode
        model.train()

        tb_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

        # each epoch
        for data in trainloader:
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            # Training accuracy comes from the batch that was just processed.
            # The previous code re-evaluated the whole training set after every
            # step, making an epoch quadratic in the number of batches and
            # leaving the model in eval mode for the following steps.
            with torch.no_grad():
                batch_hits = (outputs.argmax(dim=1) == labels).sum().item()
            acc = batch_hits * 100 / labels.size(0)

            tb_writer.add_scalar('Loss/train', loss_value, global_step)
            tb_writer.add_scalar('Acc/train', acc, global_step)
            print(f"Training loss at step {global_step} is: {loss_value}")
            global_step += 1

        # update the learning rate
        scheduler.step()

        # Evaluate after every epoch, on the same device as training rather
        # than on whatever default `test` captured when it was defined.
        acc = test(model, device=device)
        print(f"Test acc at step {global_step} is: {acc}")

        tb_writer.add_scalar('Acc/eval', acc, global_step)


In [56]:
# Train the CNN for 10 epochs with the optimizer and scheduler configured above.
training_loop(model, optimizer, scheduler, num_epochs=10)

Training loss at step 0 is: 2.3106632232666016
Training loss at step 1 is: 2.303093433380127
Training loss at step 2 is: 2.2949700355529785
Training loss at step 3 is: 2.2950339317321777
Training loss at step 4 is: 2.289426803588867
Training loss at step 5 is: 2.286762237548828
Training loss at step 6 is: 2.2789368629455566
Training loss at step 7 is: 2.2720489501953125
Training loss at step 8 is: 2.2601583003997803
Training loss at step 9 is: 2.244412422180176
Training loss at step 10 is: 2.224745988845825
Training loss at step 11 is: 2.1894540786743164
Training loss at step 12 is: 2.1639657020568848
Training loss at step 13 is: 2.0337555408477783
Training loss at step 14 is: 1.9824663400650024
Training loss at step 15 is: 1.7275285720825195
Training loss at step 16 is: 1.5443382263183594
Training loss at step 17 is: 1.4508697986602783
Training loss at step 18 is: 5.747685432434082
Training loss at step 19 is: 4.923940658569336
Training loss at step 20 is: 2.272399663925171
Training l