# In [None]:
import torch
from transformers import AutoModel

##### Q2. Implement Baseline - Adding adapter module to all transformer layers #####################
 
"""
1: Implementing 'Top' Adapter
NOTE:
    In this section, your task is to design top adapter (similar to 'head') for the given datasets. 
    You have the flexibility to design the top adapter as you see fit, 
    but it is advised to avoid excessive complexity or resource-intensive configurations.
    
    Please refer to the [BERT] class provided below.

CONSTRAINTS:
    The number of output classes differs depending on the dataset.
    Here are the specifications for each dataset:
    - BERT encoder layer produces an output shape of [B, 512, 768], 
      where B represents the batch size, 512 represents the sequence length, and 768 represents the embedding dimension. 
    - For the IMDB dataset, the model should have [2 output classes], and the loss function is CrossEntropyLoss.
    - For the SNLI dataset, the model should have [3 output classes], and the loss function is CrossEntropyLoss.
    - For the AGNews dataset, the model should have [4 output classes], and the loss function is CrossEntropyLoss.
    
    Keep in mind that an activation layer at the end of your implementation is NOT necessary. 
    This is due to the fact that the proposed loss functions already have built-in sigmoid or softmax functions.
    
HINT: 
    To implement the top adapter, 
    you might want to consider using a combination of [torch.nn.Linear] and activation layers (e.g., torch.nn.GELU())
    or utilizing [torch.nn.Sequential].
"""

# ==== 1. Implement top adapter  =========================================
class TopAdapter(torch.nn.Module):
    """Classification head applied to the feature vector produced by BERT.

    The head is a small MLP (Linear -> GELU -> Linear) that maps an
    ``input_dim``-sized feature vector to ``num_classes`` raw logits.
    No activation follows the last layer: CrossEntropyLoss applies
    log-softmax itself, and a trailing GELU would bound the logits from below.

    Args:
        num_classes: Number of output classes
            (2 for IMDB, 3 for SNLI, 4 for AGNews).
        hidden_dim: Width of the intermediate layer
            (the "hidden unit" size that Q3 asks to vary).
        input_dim: Size of the incoming feature vector (768 for BERT-base).
    """

    def __init__(self, num_classes=2, hidden_dim=256, input_dim=768):
        super().__init__()

        # Q3. Modify - Changing the size of adapter (hidden unit):
        self.adapter = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.GELU(),
            # Final projection to logits; intentionally no activation after it.
            torch.nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return logits of shape ``[..., num_classes]`` for ``x`` of shape ``[..., input_dim]``."""
        # Apply the head to the input (returning the module itself is not a tensor).
        return self.adapter(x)

# ==== 1. Implement top adapter  =========================================



"""
2: Implementing 'Layer-Wise' Adapter
NOTE:
    In this section, your task is to design layer-wise adapter for BERT's encoder layers. 
    You have the flexibility to design the layer-wise adapter as you see fit, 
    but it is advised to avoid excessive complexity or resource-intensive configurations.
    
    Please refer to the [BERT] class provided below.

CONSTRAINTS:
    Please note that the dimensions for both input and output are 768,
    and the BERT (base) model contains 12 layers.
    
HINT: 
    Consider employing [torch.nn.ModuleList] for the implementation of the layer-wise adapter.
    
"""

# ==== 2. Implement layer-wise adapter  =========================================
class LayerWiseAdapter(torch.nn.Module):
    """Bank of independent adapters, selected by index at call time.

    Each adapter is ``Linear(768, 768)`` followed by ``GELU``. The bank holds
    ``num_adapters`` of them in ``self.adapter``; ``forward(x, idx)`` runs
    ``x`` through the adapter stored at position ``idx``.

    NOTE(review): the adapter output replaces its input; there is no
    residual/skip connection around it.
    """

    def __init__(self, num_adapters=12):
        super().__init__()

        # Q4. Modify - Adding adapter module to the top half layers only:
        # HINT : original number of BERT layers is 12
        self.num_adapters = num_adapters

        # Q3. Modify - Changing the size of adapter (hidden unit):
        blocks = []
        for _ in range(num_adapters):
            projection = torch.nn.Linear(768, 768)
            blocks.append(torch.nn.Sequential(projection, torch.nn.GELU()))
        self.adapter = torch.nn.ModuleList(blocks)

    def forward(self, x, idx):
        """Run ``x`` through the adapter at position ``idx`` and return the result."""
        selected = self.adapter[idx]
        return selected(x)

    def __len__(self):
        """Return the number of adapters this bank was constructed with."""
        return self.num_adapters

# ==== 2. Implement layer-wise adapters  =========================================
##### Q2 #########################################################################



# ====================================================================
class BERT(torch.nn.Module):
    """Pretrained ``bert-base-uncased`` encoder with interchangeable adapters.

    The pretrained embeddings and encoder layers are frozen unless
    ``full_finetuning`` is True. Optional layer-wise adapters are applied to
    the outputs of the top-most encoder layers, and an optional top adapter
    (or, if absent, a linear ``pooler``) turns the first-token hidden state
    into the model output.

    Args:
        top_adapter: Module applied to the first-token hidden state of the
            final layer. When None, ``self.pooler`` is used instead.
        layer_wise_adapters: Adapter bank. Either a module called as
            ``bank(x, idx=i)`` that supports ``len()``, or a plain
            ``torch.nn.ModuleList`` indexed as ``bank[i](x)``.
        num_classes: Output size of the fallback ``pooler`` layer.
        num_adapters: Upper bound on how many of the top encoder layers
            receive a layer-wise adapter (12 = all layers, 6 = top half).
        full_finetuning: If True, the pretrained weights stay trainable.
    """

    def __init__(self, 
                 top_adapter:torch.nn.Module=None, 
                 layer_wise_adapters:torch.nn.ModuleList=None,
                 num_classes=1,
                 num_adapters=12,
                 full_finetuning=False
                 ):
        super().__init__()
        self.full_finetuning = full_finetuning
        self.num_adapters = num_adapters

        # Load pretrained BERT and keep only the embeddings and encoder layers.
        # NOTE: eval() is set here once; a later call to train() on this model
        # switches these submodules back to training mode as well.
        PLM = AutoModel.from_pretrained("bert-base-uncased")
        embeddings = PLM.embeddings
        embeddings.eval()
        layers = PLM.encoder.layer  # torch.nn.ModuleList
        layers.eval()

        # Our adaptable BERT model with interchangeable adapter layers
        self.embeddings = embeddings
        self.layers = torch.nn.ModuleList(list(layers))
        self.pooler = torch.nn.Linear(768, num_classes)

        # Q1. Find and change - Full fine-tuning setting
        # HINT: use the 'requires_grad' attribute
        # Pretrained weights are trainable only in the full fine-tuning setting.
        for param in self.embeddings.parameters():
            param.requires_grad = self.full_finetuning
        for param in self.layers.parameters():
            param.requires_grad = self.full_finetuning

        # interchangeable adapter layers
        self.top_adapter = top_adapter  # torch.nn.Module (single module)
        self.layer_wise_adapters = layer_wise_adapters  # adapter bank or ModuleList
        if layer_wise_adapters is not None:
            self.num_layer_wise_adapters = len(self.layer_wise_adapters)  # int
        else:
            self.num_layer_wise_adapters = 0

    # Q4. Modify - Adding adapter module to the top half layers only:
    def forward(self, input_ids, attention_mask):
        """Encode ``input_ids`` and return the head output for each sequence.

        Args:
            input_ids: Token ids, shape ``[B, L]``.
            attention_mask: ``[B, L]`` mask with 1 for real tokens and 0 for
                padding.

        Returns:
            Output of ``top_adapter`` (or ``pooler``) applied to the hidden
            state of the first token after the last encoder layer.
        """
        # Convert the input tokens into embeddings
        x = self.embeddings(input_ids)

        # Additive attention mask broadcastable to [B, heads, L, L]:
        # 0 where attention is allowed, -10000 at masked positions.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=x.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Adapters are attached to the top-most layers: with A usable adapters
        # and N layers, layer i (i >= N - A) uses adapter i - (N - A).
        num_layers = len(self.layers)
        num_active = 0
        if self.layer_wise_adapters is not None:
            num_active = max(
                0, min(self.num_adapters, self.num_layer_wise_adapters, num_layers)
            )
        first_adapted_layer = num_layers - num_active

        # Process the embeddings through each layer in the model's layer stack
        for idx, layer in enumerate(self.layers):
            # Apply the layer and the attention mask to generate hidden states
            x = layer(x, attention_mask=extended_attention_mask)[0]

            # (2. Your layer_wise_adapters will be applied here)
            adapter_idx = idx - first_adapted_layer
            if adapter_idx >= 0:
                if isinstance(self.layer_wise_adapters, torch.nn.ModuleList):
                    x = self.layer_wise_adapters[adapter_idx](x)
                else:
                    x = self.layer_wise_adapters(x, idx=adapter_idx)

        # Hidden state of the first token of every sequence.
        first_token_state = x[:, 0, :]
        if self.top_adapter is not None:
            # (1. Your top_adapter will be applied here)
            return self.top_adapter(first_token_state)

        # If no top adapter is defined, fall back to the linear pooling layer.
        return self.pooler(first_token_state)
# ====================================================================
