In [1]:
from crf import CRF
from transformers import RobertaModel, RobertaTokenizer
import torch.nn as nn
import numpy as np
import torch
from torch.nn.functional import log_softmax

In [2]:
# Local checkpoint directory handed to RobertaModel.from_pretrained.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
ROBERTA_PATH = '/home/dzigen/Desktop/medics2023/NLP_MODULE/models/RuBioRoBERTa'

# Device the encoder is placed on.
DEVICE = 'cpu'

# Substrings matched against encoder parameter names; parameters whose name
# contains one of them stay trainable (presumably transformer blocks 23..19
# plus the pooler — confirm against the checkpoint's parameter names).
LAYERS_TO_HOLD = [str(block_idx) for block_idx in range(23, 18, -1)] + ['pooler']

In [3]:
class RobertaCRF(nn.Module):
    """Sequence tagger: RoBERTa encoder -> dropout -> linear projection -> CRF.

    Every encoder parameter is frozen except those whose name contains one of
    the ``LAYERS_TO_HOLD`` substrings. The dropout, linear and CRF layers are
    left with their default trainability.
    """

    def __init__(self, label_size, roberta_path=ROBERTA_PATH, device=DEVICE):
        """Build the model and place all of it on ``device``.

        Args:
            label_size: Number of output labels; used as the output size of
                the linear layer and passed to ``CRF``.
            roberta_path: Location handed to ``RobertaModel.from_pretrained``.
            device: Target device for the whole module.
        """
        super().__init__()

        self.encoder = RobertaModel.from_pretrained(roberta_path)
        self.dropout = nn.Dropout(0.5)
        self.linear = nn.Linear(self.encoder.config.hidden_size, label_size)
        self.crf = CRF(label_size)

        # Freeze part of the backbone: a parameter stays trainable only when
        # its name contains one of the LAYERS_TO_HOLD markers (plain substring
        # match, so a marker also matches any longer name that contains it).
        for name, param in self.encoder.named_parameters():
            param.requires_grad = any(marker in name for marker in LAYERS_TO_HOLD)

        # Move every registered sub-module, not only the encoder; otherwise
        # the linear/CRF weights stay on the CPU and any non-CPU `device`
        # fails with a device mismatch on the first forward pass.
        self.to(device)

    def forward(self, input_ids, attention_mask, labels=None, mode='train'):
        """Run the tagger and return the CRF result for the selected mode.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Mask passed to the encoder and, as ``mask``, to
                the CRF.
            labels: Gold labels forwarded to the CRF in ``'train'`` mode;
                unused in ``'eval'`` mode.
            mode: ``'train'`` returns ``self.crf(...)`` (presumably a
                per-sequence loss — confirm against the CRF implementation);
                ``'eval'`` returns ``self.crf.viterbi_decode(...)``.

        Raises:
            KeyError: If ``mode`` is neither ``'train'`` nor ``'eval'``.

        Note:
            ``mode`` only selects the CRF branch. Dropout is governed by the
            module's own train/eval state (``model.train()`` /
            ``model.eval()``), not by this argument.
        """
        # Reject a bad mode before paying for the encoder forward pass.
        if mode not in ('train', 'eval'):
            raise KeyError(f"unknown mode {mode!r}; expected 'train' or 'eval'")

        embeddings = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        drop_out = self.dropout(embeddings.last_hidden_state)
        linear_out = self.linear(drop_out)
        # The CRF is fed log-probabilities rather than raw linear scores.
        log_out = log_softmax(linear_out, dim=-1)

        if mode == 'train':
            return self.crf(log_out, mask=attention_mask, labels=labels)
        return self.crf.viterbi_decode(log_out, mask=attention_mask)

In [4]:
# Tagger with 21 output labels, using the default checkpoint path and device.
model = RobertaCRF(label_size=21)

In [12]:
# Stand-alone copy of the pretrained encoder for step-by-step experiments:
# load the checkpoint first, then place it on the configured device.
encoder = RobertaModel.from_pretrained(ROBERTA_PATH)
encoder = encoder.to(DEVICE)

In [5]:
# Synthetic smoke-test batch: 2 sequences of 514 positions each, of which only
# the first 100 are marked as real tokens by the mask.
# NOTE(review): the random draws are unseeded, so values differ between runs.
# NOTE(review): RoBERTa checkpoints usually accept fewer positions than their
# position-embedding table size because position ids are offset by the padding
# index — confirm that a length of 514 is safe for this checkpoint.
input_ids = torch.tensor(np.random.randint(0, 100, size=(2, 514)))

# uint8 mask: ones over the first 100 positions, zeros over the remaining 414.
mask = torch.zeros(2, 514, dtype=torch.uint8)
mask[:, :100] = 1

# A single random label row (values 0..2 on the unmasked part, 0 elsewhere)
# repeated for both sequences.
labels = torch.tensor(2 * [[*np.random.randint(3, size=100), *(414 * [0])]])

In [37]:
# Shape check of the token-id batch.
input_ids.shape

torch.Size([2, 514])

In [41]:
# Shape check of the label batch.
labels.shape

torch.Size([2, 514])

In [38]:
# Shape check of the attention mask.
mask.shape

torch.Size([2, 514])

In [57]:
# Random 10x10 scores turned into row-wise log-probabilities.
output = torch.randn(10, 10).log_softmax(dim=-1)

In [58]:
# Ten random class indices in [0, 10).
target = torch.randint(low=0, high=10, size=(10,))

In [59]:
# `criterion` was not defined in any cell of this notebook, so this cell only
# worked through leftover kernel state. `output` holds log-probabilities, for
# which negative log-likelihood is the matching loss.
# NOTE(review): nn.NLLLoss is an assumption — confirm it is the loss that
# produced the recorded value.
criterion = nn.NLLLoss()
loss = criterion(output, target)
print(loss)

tensor(2.4786)


In [None]:
# NOTE(review): scratch cell that was never executed (no execution count).
# `criterion` is called with no arguments and is not defined in any visible
# cell — complete this call or remove the cell.
criterion()

In [8]:
import warnings
warnings.filterwarnings("ignore")

In [9]:
# Train-mode forward pass of the full tagger on the synthetic batch.
out = model(input_ids=input_ids, attention_mask=mask, labels=labels, mode='train')

In [10]:
# Show the train-mode result: one value per sequence in the batch.
out

tensor([346.0506, 331.8098], grad_fn=<SubBackward0>)

In [9]:
# Batch-averaged value of the train-mode output.
torch.mean(out)

tensor(339.4282, grad_fn=<MeanBackward0>)

In [34]:
# `out` holds one value per sequence, and backward() without an explicit
# gradient argument only works on a scalar, so reduce over the batch first.
out.mean().backward()

In [25]:
# Run the stand-alone encoder on the token ids only.
# NOTE(review): no attention_mask is passed here, unlike RobertaCRF.forward —
# confirm that attending to the padded positions is intended.
# NOTE(review): this rebinds `output`, which an earlier cell used for the
# random log-probabilities.
output = encoder(input_ids=input_ids)

In [105]:
# Display the full encoder output object (last_hidden_state, pooler_output, ...).
output

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 0.6007, -0.3789,  1.6257,  ..., -0.2636,  0.1943, -0.9833],
         [-0.2066,  1.4493, -0.2869,  ...,  0.4865,  1.6656,  0.2705],
         [-0.0813,  3.3444,  0.7160,  ..., -0.0481,  1.0002, -0.3339],
         ...,
         [-0.5181,  1.2309,  0.9647,  ...,  0.3538,  0.5873, -0.2277],
         [ 0.0342,  1.9868, -0.3185,  ...,  0.5481,  2.0470,  0.2225],
         [ 0.0752,  2.7654,  0.1172,  ..., -0.3603,  0.8225,  0.3939]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 0.0290,  0.5206, -0.6405,  ...,  0.5000, -0.3417, -0.4772]],
       grad_fn=<TanhBackward0>), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)

In [34]:
# Stand-alone projection from the encoder's hidden size to 21 label scores.
linear = nn.Linear(in_features=encoder.config.hidden_size, out_features=21)

In [37]:
# Project the encoder's last hidden states to per-position label scores.
l_out = linear(output.last_hidden_state)

In [47]:
# Shape check of the projected scores.
# NOTE(review): the recorded size has batch dimension 1 while `input_ids` is
# built with 2 sequences — the stored output looks stale; re-run to confirm.
l_out.shape

torch.Size([1, 514, 21])

In [62]:
# Stand-alone CRF over 21 labels, matching the output size of `linear`.
crf = CRF(21)

In [64]:
# CRF training output on the projected scores.
# NOTE(review): unlike RobertaCRF.forward, the scores are passed without
# log_softmax here, and `loss` from the earlier criterion cell is rebound.
loss = crf(l_out, mask=mask, labels=labels)

In [70]:
# Viterbi decoding of the projected scores under the same mask.
crf_out = crf.viterbi_decode(l_out, mask=mask)

In [73]:
# Number of decoded tags for the first sequence; the recorded value (100)
# matches the 100 unmasked positions in `mask`.
len(crf_out[0])

100

In [None]:
# NOTE(review): leftover fragment in a never-executed cell. As written it is a
# valid statement that rebinds `labels` to the one-element tuple `(labels,)`;
# remove it or restore the call it was cut from.
labels=labels,