In [5]:
import torch
import torch.nn as nn
import transformers
from transformers import (
    AutoConfig,
    AutoModel,
)

In [6]:
class LitModel(nn.Module):
    """Transformer backbone + attention pooling + linear regression head.

    Args:
        model_name_or_path: checkpoint name or path handed to
            ``AutoConfig.from_pretrained`` and ``AutoModel.from_pretrained``.
        reinit_layers: number of trailing encoder layers re-initialised after
            loading the pretrained weights (``<= 0`` disables it). Defaults to 4,
            the value previously hard-coded.
        mask_padding: when True (default), positions whose ``attention_mask``
            entry is 0 receive zero pooling weight, so padding cannot influence
            the prediction. Set False to reproduce the old unmasked pooling
            (e.g. for checkpoints trained with it).
    """

    def __init__(self, model_name_or_path="roberta-base", reinit_layers: int = 4, mask_padding: bool = True):
        super().__init__()

        self.config = AutoConfig.from_pretrained(model_name_or_path)
        self.config.update({
            "output_hidden_states": True,
            "hidden_dropout_prob": 0.0,
            "layer_norm_eps": 1e-7,
        })

        self.roberta = AutoModel.from_pretrained(model_name_or_path, config=self.config)

        hidden_size = self.config.hidden_size
        # One score per position, normalised with Softmax over dim 1
        # (presumably the sequence axis of (batch, seq, hidden) -- TODO confirm).
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.Tanh(),
            nn.Linear(512, 1),
            nn.Softmax(dim=1),
        )

        self.regressor = nn.Sequential(
            nn.Linear(hidden_size, 1),
        )

        self.mask_padding = mask_padding

        self._init_embed_layers(reinit_layers=reinit_layers)

    def _init_embed_layers(self, reinit_layers: int = 4):
        """Re-initialise the last ``reinit_layers`` encoder layers in place.

        Despite its name this touches ``self.roberta.encoder.layer[-N:]``, not
        the embedding layer. Linear/Embedding weights are drawn from
        N(0, config.initializer_range); Linear biases and the padding row of an
        Embedding are zeroed; LayerNorms are reset to weight=1, bias=0.
        A non-positive ``reinit_layers`` is a no-op.
        """
        if reinit_layers > 0:
            for layer in self.roberta.encoder.layer[-reinit_layers:]:
                for module in layer.modules():
                    if isinstance(module, nn.Linear):
                        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
                        if module.bias is not None:
                            module.bias.data.zero_()
                    elif isinstance(module, nn.Embedding):
                        # NOTE(review): encoder layers may contain no Embedding;
                        # branch kept as a defensive catch-all.
                        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
                        if module.padding_idx is not None:
                            module.weight.data[module.padding_idx].zero_()
                    elif isinstance(module, nn.LayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)

    def forward(self, input_ids, attention_mask):
        """Predict one regression score per sequence.

        Args:
            input_ids: token ids forwarded to the backbone.
            attention_mask: forwarded to the backbone and, when
                ``mask_padding`` is on, used to zero the pooling weight of
                masked positions. Assumed shape (batch, seq) -- TODO confirm.

        Returns:
            Output of ``self.regressor`` applied to the pooled context vector.
        """
        roberta_output = self.roberta(input_ids=input_ids, attention_mask=attention_mask)

        last_layer_hidden_states = roberta_output.hidden_states[-1]
        weights = self.attention(last_layer_hidden_states)

        # getattr keeps objects pickled before this attribute existed working.
        if getattr(self, "mask_padding", True) and attention_mask is not None:
            # Zeroing masked softmax outputs and renormalising equals a softmax
            # over the unmasked positions only. The clamp avoids 0/0 when a
            # row is entirely masked (its context vector then becomes zero).
            mask = attention_mask.unsqueeze(-1).to(weights.dtype)
            weights = weights * mask
            denom = weights.sum(dim=1, keepdim=True).clamp_min(torch.finfo(weights.dtype).tiny)
            weights = weights / denom

        context_vector = torch.sum(weights * last_layer_hidden_states, dim=1)
        # Now we reduce the context vector to the prediction score.
        return self.regressor(context_vector)

In [34]:
# Build the regressor; __init__ also re-initialises the last encoder layers randomly.
model = LitModel(model_name_or_path="roberta-base")


In [69]:
def create_optimizer(model):
    """Build an AdamW optimizer with layer-wise learning rates, selecting
    parameters by their *position* in ``model.named_parameters()``.

    The positions below are hard-coded for ``LitModel("roberta-base")``; any
    other architecture will be silently mis-grouped.

    Args:
        model: module whose ``named_parameters()`` order matches
            ``LitModel("roberta-base")`` (assumed, not checked).

    Returns:
        ``torch.optim.AdamW`` with one group for the attention head, one for
        the regressor, and one group per backbone parameter. Biases get
        ``weight_decay=0.0``, other backbone parameters 0.01; backbone
        positions [0, 69) use lr 2e-5, [69, 133) 5e-5, [133, 197) 1e-4.
    """
    # Assumed positional layout of LitModel("roberta-base") parameters:
    #   [0, 197)    backbone embeddings + encoder layers
    #   [197, 199)  presumably the backbone pooler -- deliberately not optimized
    #   [199, 203)  attention-pooling head
    #   [203, ...)  regressor head
    backbone_end, attention_start, regressor_start = 197, 199, 203
    # Backbone positions at or beyond these indices get the higher rates.
    mid_lr_start, top_lr_start = 69, 133

    named_parameters = list(model.named_parameters())

    roberta_parameters = named_parameters[:backbone_end]
    attention_parameters = named_parameters[attention_start:regressor_start]
    regressor_parameters = named_parameters[regressor_start:]

    # Head groups carry no explicit lr/weight_decay -> optimizer defaults.
    parameters = [
        {"params": [param for _, param in attention_parameters]},
        {"params": [param for _, param in regressor_parameters]},
    ]

    for layer_num, (name, param) in enumerate(roberta_parameters):
        weight_decay = 0.0 if "bias" in name else 0.01
        if layer_num >= top_lr_start:
            lr = 1e-4
        elif layer_num >= mid_lr_start:
            lr = 5e-5
        else:
            lr = 2e-5
        parameters.append({"params": param, "weight_decay": weight_decay, "lr": lr})

    # `AdamW` was never imported in this notebook (NameError). Use torch's
    # implementation, pinning the defaults of the deprecated transformers.AdamW
    # (eps=1e-6, weight_decay=0.0) so the head groups are not silently decayed
    # by torch's default weight_decay=0.01.
    return torch.optim.AdamW(parameters, lr=1e-3, eps=1e-6, weight_decay=0.0)

In [71]:
# Optimizer from the index-based create_optimizer defined in the cell above.
past_result = create_optimizer(model=model)

In [72]:
def create_optimizer(model):
    """Build an AdamW optimizer with layer-wise learning rates chosen by
    parameter *name* rather than by position.

    Parameters whose name contains ``"roberta"`` are backbone parameters and
    get one param group each:

    * encoder ``layer.0`` .. ``layer.4``  -> lr 2e-5
    * encoder ``layer.5`` .. ``layer.8``  -> lr 5e-5
    * encoder ``layer.9`` and above       -> lr 1e-4
    * backbone params not inside ``layer.<N>`` (e.g. embeddings) -> lr 1e-4

    Backbone biases get ``weight_decay=0.0``, other backbone params 0.01.
    All remaining (head) parameters share one group using optimizer defaults.

    Args:
        model: module whose backbone parameter names contain ``"roberta"`` and
            ``"layer.<N>."`` for encoder layers (as in ``LitModel``).

    Returns:
        ``torch.optim.AdamW`` covering every parameter of ``model``.
    """
    import re

    # Anchored on both dots so "layer.1" can no longer match "layer.10" or
    # "layer.11" (the old substring test put layers 10/11 in the 2e-5 group).
    layer_pattern = re.compile(r"(?:^|\.)layer\.(\d+)(?:\.|$)")

    def backbone_lr(name):
        match = layer_pattern.search(name)
        if match is None:
            # NOTE(review): the index-based variant in this notebook gives the
            # embeddings 2e-5 and splits layers 0-3/4-7/8-11; this version keeps
            # its own 1e-4 fallback and 0-4/5-8/9-11 split -- confirm intent.
            return 1e-4
        layer = int(match.group(1))
        if layer <= 4:
            return 2e-5
        if layer <= 8:
            return 5e-5
        return 1e-4

    named_parameters = list(model.named_parameters())
    roberta_parameters = [(n, p) for n, p in named_parameters if "roberta" in n]
    not_roberta_group = [p for n, p in named_parameters if "roberta" not in n]

    parameters = [{"params": not_roberta_group}]
    for name, params in roberta_parameters:
        weight_decay = 0.0 if "bias" in name else 0.01
        parameters.append({"params": params, "weight_decay": weight_decay, "lr": backbone_lr(name)})

    # `AdamW` was never imported in this notebook (NameError). Use torch's
    # implementation, pinning the defaults of the deprecated transformers.AdamW
    # (eps=1e-6, weight_decay=0.0) so the head group is not silently decayed
    # by torch's default weight_decay=0.01.
    return torch.optim.AdamW(parameters, lr=1e-3, eps=1e-6, weight_decay=0.0)

In [73]:
# Optimizer from the name-based create_optimizer above; presumably compared against `past_result`.
new_result = create_optimizer(model=model)