# Encoder Classifier
To sanity-check the transformer encoder, we'll build a simple classifier and make sure the model is able to overfit the data.

In [21]:
from vit.encoder import TransformerEncoder

Now we build the classifier by adding a classification head on top of the transformer encoder.

In [16]:
import torch 
from torch import nn
from vit.encoder import TransformerEncoder
class TransformerEncoderClassifer(nn.Module):
    """Sequence classifier: a ``TransformerEncoder`` followed by a small MLP head.

    The hidden state at position 0 of each sequence is used as the pooled
    representation and passed through
    ``Linear(d_model, d_model) -> ReLU -> Dropout -> Linear(d_model, num_class)``.

    (The class name keeps its original spelling so existing callers still work.)

    Args:
        vocab_size: Forwarded unchanged to ``TransformerEncoder``.
        d_model: Forwarded to ``TransformerEncoder``; also the input/output size
            of the ``pre_classifier`` layer and the input size of ``classifier``.
        num_layer: Forwarded unchanged to ``TransformerEncoder``.
        num_head: Forwarded unchanged to ``TransformerEncoder``.
        d_k: Forwarded unchanged to ``TransformerEncoder``.
        dropout_rate: Forwarded to ``TransformerEncoder`` and used by the
            dropout layer of the classification head.
        num_class: Number of output logits per sequence.
    """

    def __init__(
        self,
        vocab_size,
        d_model,
        num_layer,
        num_head,
        d_k,
        dropout_rate,
        num_class,
    ) -> None:
        super().__init__()
        self.encoder = TransformerEncoder(
            vocab_size=vocab_size,
            d_model=d_model,
            num_layer=num_layer,
            num_head=num_head,
            d_k=d_k,
            dropout_rate=dropout_rate,
        )
        self.pre_classifier = nn.Linear(d_model, d_model)
        self.classifier = nn.Linear(d_model, num_class)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, attn_mask=None):
        """Return unnormalized class logits for a batch of token sequences.

        Args:
            x: Token ids, presumably of shape [batch_size, seq_len]
                (TODO confirm against ``TransformerEncoder``).
            attn_mask: Optional mask, passed through unchanged to the encoder.

        Returns:
            Logits of shape [batch_size, num_class]; no softmax is applied.
        """
        # [batch_size, seq_len] -> [batch_size, seq_len, d_model]
        hidden = self.encoder(x, attn_mask=attn_mask)
        # Pool by taking the first token's hidden state:
        # [batch_size, seq_len, d_model] -> [batch_size, d_model]
        pooled = hidden[:, 0]
        # Functional ReLU: same result as nn.ReLU() without building a new
        # module on every forward call.
        pooled = torch.relu(self.pre_classifier(pooled))
        pooled = self.dropout(pooled)
        return self.classifier(pooled)


In [17]:
def classifier_unit_test():
    """Smoke-test ``TransformerEncoderClassifer``: output shape and gradient flow.

    Builds a classifier on CPU and runs one forward pass on random token ids.
    It then backpropagates a dummy scalar loss and checks that every parameter
    received a gradient.

    Raises:
        AssertionError: if the logits have an unexpected shape, or if any
            parameter was not reached by backpropagation. The message lists
            the offending parameter names.
    """
    import torch

    batch_size, seq_len, vocab_size, num_class = 2, 128, 1000, 2
    classifier = TransformerEncoderClassifer(
        vocab_size=vocab_size,
        d_model=512,
        num_layer=6,
        num_head=8,
        d_k=64,
        dropout_rate=0.1,
        num_class=num_class,
    ).to('cpu')
    dummy_inputs = torch.randint(0, vocab_size, (batch_size, seq_len)).to('cpu')

    logits = classifier(dummy_inputs)
    assert logits.shape == (batch_size, num_class), (
        f"unexpected logits shape {tuple(logits.shape)}"
    )

    # Any scalar that depends on every output is enough to exercise backprop.
    logits.mean().backward()
    missing = [
        name for name, param in classifier.named_parameters() if param.grad is None
    ]
    assert not missing, f"parameters without gradient: {missing}"

In [18]:
# Run the smoke test defined above. It raises AssertionError on failure and
# returns None, so a passing run shows no output.
classifier_unit_test()

In [19]:
import torch

# Build a classifier and run one forward pass on random token ids so we can
# inspect the raw logits.
torch.manual_seed(0)  # make the random weights and inputs reproducible
classifier = TransformerEncoderClassifer(
    vocab_size=1000,
    d_model=512,
    num_layer=6,
    num_head=8,
    d_k=64,
    dropout_rate=0.1,
    num_class=2,
).to('cpu')
dummy_inputs = torch.randint(0, 1000, (2, 128)).to('cpu')  # [batch_size, seq_len]
# NOTE(review): random 0/1 mask, not passed to the forward call below.
dummy_attn_masks = torch.randint(0, 2, (2, 128)).to('cpu')
y = classifier(dummy_inputs)

In [20]:
# Display the raw logits from the forward pass above: one row per input
# sequence, one column per class (no softmax applied).
y

tensor([[-0.5664, -0.0010],
        [-0.2738,  0.2188]], grad_fn=<AddmmBackward0>)