# RuBERT (DeepPavlov/rubert-base-cased) sequence classifier

In [None]:
import os
from collections import defaultdict, Counter
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from transformers import BertForSequenceClassification, BertTokenizer
import pytorch_lightning as pl
from tqdm.notebook import tqdm
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Training hyper-parameters and the Hugging Face checkpoint to fine-tune.
# DeepPavlov/rubert-base-cased-conversational was also tried.
batch_size = 14
model_name = 'DeepPavlov/rubert-base-cased'

# Train / test / validation splits.
train = pd.read_csv('../input/ruatd2022/train.csv')
test = pd.read_csv('../input/attest/test.csv')
val = pd.read_csv('../input/atdval/val.csv')

# Integer-encode the string class labels; the encoder is fitted on the training split only.
le = LabelEncoder().fit(train['Class'].values)

In [None]:
# Tokenizer for `model_name`; used by `collate_fn` to encode each batch of texts.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
def collate_fn(input_data):
    """Turn a list of ``(text, label)`` pairs into one tokenized batch.

    Texts are padded to the longest item in the batch and truncated to 256
    tokens by the module-level ``tokenizer``. Labels are attached under the
    ``'Class'`` key as a LongTensor; for unlabeled items (label ``[]``) that
    tensor has shape ``(batch, 0)``.
    """
    texts, labels = zip(*input_data)
    label_tensor = torch.LongTensor(labels)
    encoded = tokenizer(
        texts,
        return_tensors='pt',
        padding='longest',
        max_length=256,
        truncation=True,
    )
    encoded['Class'] = label_tensor
    return encoded

class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset of ``(text, label)`` pairs built from a DataFrame.

    Args:
        data: DataFrame with a ``Text`` column and, optionally, a ``Class`` column.
        sort: accepted for API compatibility; not used.
        le: optional fitted encoder whose ``transform`` maps ``Class`` values to
            integer ids. When omitted, the raw ``Class`` values are used as labels.

    Items are ``(text, label)`` when the frame has a ``Class`` column and
    ``(text, [])`` otherwise.
    """

    def __init__(self, data, sort=False, le=None):
        super().__init__()
        self.texts = data['Text'].values
        if 'Class' not in data.columns:
            # Unlabeled split: deliberately no ``labels`` attribute.
            return
        assert not data['Class'].isnull().any(), "Some labels are null"
        self.labels = data['Class'].values if le is None else le.transform(data['Class'])

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        label = self.labels[idx] if hasattr(self, 'labels') else []
        return self.texts[idx], label

class Metric:
    """Accumulate per-batch scalar metrics (e.g. accuracy, loss) and report their means."""

    def __init__(self):
        # metric name -> list of stored values
        self.storage = defaultdict(list)

    def store(self, **kwargs):
        """Append one value per keyword, e.g. ``store(loss=0.3, accuracy=0.9)``."""
        for key, value in kwargs.items():
            self.storage[key].append(value)

    def reset(self):
        """Discard all stored values."""
        self.storage.clear()

    def log(self):
        """Return ``(name, mean)`` pairs for every stored metric.

        The stored values are left untouched, so ``store`` can safely be
        called again afterwards with or without a ``reset``.
        """
        return {key: np.mean(values) for key, values in self.storage.items()}.items()
        
class BertClassifier(pl.LightningModule):
    """Lightning wrapper around a Hugging Face sequence-classification model.

    Args:
        model_name: checkpoint name/path for
            ``AutoModelForSequenceClassification.from_pretrained``.
        lr: AdamW learning rate.
        num_labels: number of output classes.
        train_log_interval: number of training batches averaged into each
            logged ``train/*`` value.
    """

    def __init__(self, model_name, lr=1e-5, num_labels=2, train_log_interval=100):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        self.metric = Metric()
        self.learning_rate = lr
        self.train_log_interval = train_log_interval

    def configure_optimizers(self):
        return torch.optim.AdamW(self.bert.parameters(), lr=self.learning_rate)

    def forward(self, x):
        """Run the wrapped model on a tokenized batch dict (without the ``'Class'`` key)."""
        return self.bert(**x)

    def training_step(self, batch, batch_idx):
        labels = batch.pop('Class')
        logits = self.bert(**batch).logits
        loss = F.cross_entropy(logits, labels)
        accuracy = (logits.argmax(dim=1) == labels).double().mean()
        self.metric.store(loss=loss.item(), accuracy=accuracy.item())
        # Log once every `train_log_interval` batches: the mean over the batches
        # stored since the previous log, then start a fresh window.
        if (batch_idx + 1) % self.train_log_interval == 0:
            for k, v in self.metric.log():
                self.log(f'train/{k}', v)
            self.metric.reset()
        return loss

    def validation_step(self, batch, batch_idx):
        labels = batch.pop('Class')
        logits = self.bert(**batch).logits
        loss = F.cross_entropy(logits, labels)
        self.log('val/loss', loss)
        predictions = logits.argmax(dim=1)
        self.log('val/accuracy', (predictions == labels).double().mean())

In [None]:
train = TextDataset(train, le=le)
val = TextDataset(val, le=le)
test = TextDataset(test, le=le)

# Shuffle only the training data so batches do not follow the CSV row order;
# val/test keep file order (test predictions must line up with the submission rows).
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

model = BertClassifier(model_name, num_labels=len(le.classes_)).cuda()

# Freeze encoder layers 1-4 (layer 0 stays trainable; TODO confirm 0-3 was not intended).
# Parameter names are prefixed by the wrapper attribute plus the HF model's own
# base-encoder attribute (for BERT checkpoints: "bert.bert.encoder.layer.<i>...."),
# so match the "encoder.layer.<i>." fragment anywhere in the name. The trailing
# dot stops layer 1 from also matching layers 10 and 11.
FROZEN_LAYERS = (1, 2, 3, 4)
frozen_fragments = tuple(f'encoder.layer.{i}.' for i in FROZEN_LAYERS)
n_frozen = 0
for name, param in model.named_parameters():
    if any(fragment in name for fragment in frozen_fragments):
        param.requires_grad = False
        n_frozen += 1
if n_frozen == 0:
    raise RuntimeError(f"No parameters matched {frozen_fragments}; layer freezing had no effect")

version = f"{model_name}_binary"
logger = pl.loggers.TensorBoardLogger(save_dir=os.getcwd(), name='lightning_logs', version=version)
trainer = pl.Trainer(
    logger=logger,
    gpus=[0],  # NOTE(review): removed in PyTorch Lightning 2.x; there use accelerator='gpu', devices=[0]
    max_epochs=3,
    num_sanity_val_steps=1
)
trainer.fit(model, train_loader, val_loader)

In [None]:
model.cuda()


def get_accuracy_and_pred(model, loader):
    """Run ``model`` over ``loader``; return accuracy and predicted class ids.

    Args:
        model: module whose ``forward`` takes a tokenized batch dict and returns an
            object with ``.logits``; must expose ``.device`` (e.g. ``BertClassifier``).
        loader: DataLoader yielding ``collate_fn`` batches with a ``'Class'`` entry.
            Unlabeled items produce an empty ``(batch, 0)`` label tensor.

    Returns:
        ``(accuracy, predictions)``: accuracy as a float over the labeled examples
        (``nan`` when the loader contains no labels), and a 1-D numpy array of
        predicted class ids in loader order.
    """
    model.eval()
    preds = []
    n_correct = 0
    n_labeled = 0
    for batch in tqdm(loader):
        for key in batch:
            batch[key] = batch[key].to(model.device)
        labels = batch.pop('Class')

        with torch.no_grad():
            pred = model(batch).logits.argmax(dim=1)

        # Labeled batches carry a 1-D label tensor; unlabeled ones an empty tensor.
        if labels.numel() > 0:
            n_correct += (pred == labels).sum().item()
            n_labeled += labels.numel()
        preds.append(pred.cpu().numpy())

    accuracy = n_correct / n_labeled if n_labeled else float('nan')
    predictions = np.concatenate(preds) if preds else np.empty(0, dtype=np.int64)
    return accuracy, predictions


acc, preds = get_accuracy_and_pred(model, test_loader)

In [None]:
# Persist the test predictions as decoded class names, then report the accuracy.
np.save(file='test_preds_rubert_based_frozen.npy', arr=le.inverse_transform(preds))
print("Test accuracy: {}".format(acc))

In [None]:
# Reload the saved test predictions (decoded class names).
f = np.load('test_preds_rubert_based_frozen.npy', allow_pickle=True)

# Use this CSV as a row/column template: drop its Class column and
# append this model's decoded predictions as the new Class column.
df = pd.read_csv('../input/sample/rubert-base-cased-conversational_128.csv', index_col=False)
df = df.drop(columns='Class')
df['Class'] = le.inverse_transform(preds)
# NOTE(review): output file has no ".csv" extension — confirm this is intended.
df.to_csv('submitrubert_bert_frozen', index=False)