In [None]:
!pip install transformers datasets sentencepiece --quiet

In [7]:
import pandas as pd
import torch

In [8]:
import datasets

In [9]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [10]:
from bs4 import BeautifulSoup

In [11]:
from sklearn.model_selection import train_test_split

# The data

In [13]:
data = pd.read_excel('export-26-10-20.xlsx', engine='openpyxl')
print(data.shape)

(614, 26)


In [14]:
data.isnull().mean()

Article Name                 0.000000
PMID                         0.000000
Keywords                     0.700326
Date of Publication          0.003257
Abstract                     0.000000
Journal                      0.004886
Research Model               0.000000
Indication                   0.001629
MESH                         1.000000
Compound                     0.001629
Drug                         0.897394
Effect                       0.000000
Route of Administration      0.651466
Dose                         0.674267
Regimen                      0.926710
Side Effects                 0.907166
Minimum Age                  0.908795
Maximum Age                  0.918567
Gender                       0.824104
Ethnicity                    0.946254
Number of Patients           0.701954
Type of Clinical Trial       0.802932
Comorbidity                  0.871336
Drug Given in Combination    0.877850
Cell Line                    0.827362
Animal                       0.543974
dtype: float

In [15]:
data.dropna(subset={'Indication', 'Compound'}, inplace=True)
print(data.shape)

(612, 26)


In [16]:
def get_abstact_text(text):
    return BeautifulSoup(text).text

In [17]:
data.apply(lambda row: row.Indication.lower() in row.Abstract.lower(), axis=1).mean()

0.6405228758169934

In [18]:
data.apply(lambda row: row.Compound.lower() in row.Abstract.lower(), axis=1).mean()

0.8529411764705882

In [19]:
data.sample(3)

Unnamed: 0,Article Name,PMID,Keywords,Date of Publication,Abstract,Journal,Research Model,Indication,MESH,Compound,Drug,Effect,Route of Administration,Dose,Regimen,Side Effects,Minimum Age,Maximum Age,Gender,Ethnicity,Number of Patients,Type of Clinical Trial,Comorbidity,Drug Given in Combination,Cell Line,Animal
177,Comparison of the analgesic effects of dronabi...,23609132,,2013.0,Recent studies have demonstrated the therapeut...,Neuropsychopharmacology : official publication...,clinical trial,Pain,,THC,,positive,,,,,,,both,,30.0,"randomized, placebo-controlled, double-dummy, ...",,,,
402,Cannabidiol increases survival and promotes re...,25595981,,2015.0,Cannabidiol increases survival and promotes re...,Neuroscience,in vivo,Cerebral malaria,,CBD,,positive,Injected IP,30mg/kg/day,,,,,female,,,,,Artesunate,,mice
562,Single and combined effects of plant-derived a...,29338068,,2019.0,"<p style=""box-sizing: inherit; line-height: 1....",British journal of pharmacology,in vivo,Cognitive impairment,,CBD,,NONE,,,,,,,,C57BL/6,,,,,,mice


In [20]:
data['text_lhs'] = data.apply(lambda row: f'Compound: {row.Compound}; Indication: {row.Indication}', axis=1)

In [21]:
data.text_lhs.value_counts().head(10)

Compound: THC; Indication: Pain                      30
Compound: CBD; Indication: Epilepsy                  26
Compound: CBD; Indication: Inflammation              24
Compound: CBD; Indication: Fear Anxiety PTSD         23
Compound: THC+CBD; Indication: Multiple sclerosis    20
Compound: THC; Indication: Nausea and vomiting       19
Compound: THC; Indication: Inflammation              14
Compound: CBD; Indication: Schizophrenia             14
Compound: THC; Indication: Fear Anxiety PTSD         13
Compound: THC; Indication: Safety                    11
Name: text_lhs, dtype: int64

In [22]:
data.text_lhs.sample(5)

278        Compound: THC; Indication: Cerebral ischaemia
393                      Compound: THC; Indication: Pain
396             Compound: Whole Leaf; Indication: Stress
585     Compound: THC; Indication: Tolerance development
596    Compound: Synthetic Cannabinoids; Indication: ...
Name: text_lhs, dtype: object

In [23]:
data['text_rhs'] = data.apply(lambda row: f'Title: {row["Article Name"]}\n{get_abstact_text(row.Abstract)}', axis=1)

In [24]:
print(data.text_rhs.sample(1).iloc[0])

Title: Anti-inflammatory and antioxidant effects of a combination of cannabidiol and moringin in LPS-stimulated macrophages.
Inflammatory response plays an important role in the activation and progress of many debilitating diseases. Natural products, like cannabidiol, a constituent of Cannabis sativa, and moringin, an isothiocyanate obtained from myrosinase-mediated hydrolysis of the glucosinolate precursor glucomoringin present in Moringa oleifera seeds, are well known antioxidants also endowed with anti-inflammatory activity. This is due to a covalent-based mechanism for ITC, while non-covalent interactions underlie the activity of CBD. Since these two mechanisms are distinct, and the molecular endpoints are potentially complementary, we investigated in a comparative way the protective effect of these compounds alone or in combination on lipopolysaccharide-stimulated murine macrophages. Our results show that the cannabidiol (5־¼M) and moringin (5־¼M) combination outperformed the sing

In [25]:
data.Effect.value_counts()

positive    462
negative     84
NONE         54
SAFE         12
Name: Effect, dtype: int64

In [26]:
data_trainval, data_test = train_test_split(data, test_size=0.2, random_state=1)

In [27]:
data_train, data_val = train_test_split(data_trainval, test_size=0.2, random_state=1)

In [28]:
data_train.shape, data_val.shape, data_test.shape

((391, 28), (98, 28), (123, 28))

In [29]:
all_labels = sorted(set(data.Effect))
all_labels

['NONE', 'SAFE', 'negative', 'positive']

In [30]:
from datasets import Dataset, DatasetDict

In [31]:
dataset_dict = DatasetDict({
    k: Dataset.from_dict({
        'text1': v.text_lhs,
        'text2': v.text_rhs,
        'label': v.Effect.apply(lambda x: all_labels.index(x))
    })
    for k, v in zip(['train', 'val', 'test'], [data_train, data_val, data_test])
})

# The model

In [82]:
# a medical BERT from https://academic.oup.com/bioinformatics/article/36/4/1234/5566506
model_checkpoint = 'dmis-lab/biobert-base-cased-v1.2'
# a roberta-like model for longer texts, but without token_type_ids
# model_checkpoint = 'allenai/longformer-base-4096'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=len(all_labels))
if torch.cuda.is_available():
    model.cuda()

Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.2 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

In [83]:
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer)

In [84]:
tokenizer('hello', 'world')

{'input_ids': [101, 19082, 102, 1362, 102], 'token_type_ids': [0, 0, 0, 1, 1], 'attention_mask': [1, 1, 1, 1, 1]}

In [85]:
def tokenize(x):
    return tokenizer(x["text1"], x['text2'], truncation=True, padding=True)

dataset_dict_tokenized = dataset_dict.map(tokenize, batched=True)

pd.Series(
    [len([tok for tok in text if tok != tokenizer.pad_token_id]) for text in dataset_dict_tokenized['train']['input_ids']]
).quantile([0.5, 0.8, 0.9, 0.95, 0.99, 1])

  0%|          | 0/1 [00:00<?, ?ba/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

0.50     466.0
0.80     562.0
0.90     604.0
0.95     653.5
0.99     912.5
1.00    1104.0
dtype: float64

Most of the texts fit into 512 tokens, but for longer texts we might need to fine tune a longformer instead of a BERT

In [86]:
def tokenize(x):
    return tokenizer(x["text1"], x['text2'], truncation=True, padding=True, max_length=model.config.max_position_embeddings)

dataset_dict_tokenized = dataset_dict.map(tokenize, batched=True)

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [87]:
from torch.utils.data import DataLoader

In [88]:
batch_size = 16  # maximum size for colab 

loaders = {
    k : DataLoader(
        v.remove_columns(['text1', 'text2']), 
        batch_size=batch_size, 
        drop_last=False, 
        shuffle=True, 
        num_workers=1, 
        collate_fn=data_collator
    )
    for k, v in dataset_dict_tokenized.items()
}

In [121]:
from tqdm.auto import tqdm, trange
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

In [128]:
def batched_predict(model, dev_dataloader):
    preds = []
    facts = []

    for batch in tqdm(dev_dataloader):
        facts.append(batch.labels.cpu().numpy())
        batch = batch.to(model.device)
        with torch.no_grad():
            pr = model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, token_type_ids=batch.token_type_ids)
        preds.append(torch.softmax(pr.logits, -1).cpu().numpy())
    facts = np.concatenate(facts)
    preds = np.concatenate(preds)
    return facts, preds

In [90]:
def evaluate_model(model, dev_dataloader):
    facts, preds = batched_predict(model, dev_dataloader)

    p, r, f, s = precision_recall_fscore_support(facts, preds.argmax(axis=1), zero_division=0)
    results = {f'recall_{k}': v for k, v in zip(all_labels, r)}
    results['macro_recall'] = r.mean()
    results['macro_f'] = f.mean()
    results['accuracy'] = (facts == preds.argmax(axis=1)).mean()

    return results

model.eval()
print(evaluate_model(model, loaders['val']))

  0%|          | 0/7 [00:00<?, ?it/s]

{'recall_NONE': 0.75, 'recall_SAFE': 0.6666666666666666, 'recall_negative': 0.0, 'recall_positive': 0.0, 'macro_recall': 0.35416666666666663, 'macro_f': 0.07724137931034483, 'accuracy': 0.08163265306122448}


In [75]:
import gc

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    
cleanup()

Замораживаем трансформерные слои, кроме последних трёх. Вроде как это может помочь меньше переобучаться. 

In [112]:
model.requires_grad_(True)
print(len([p for p in model.parameters() if p.requires_grad]))
model.bert.requires_grad_(False);
print(len([p for p in model.parameters() if p.requires_grad]))
model.bert.pooler.requires_grad_(True);
print(len([p for p in model.parameters() if p.requires_grad]))
model.bert.embeddings.requires_grad_(True)
print(len([p for p in model.parameters() if p.requires_grad]))
for layer in model.bert.encoder.layer[-3:]:
    layer.requires_grad_(True);
print(len([p for p in model.parameters() if p.requires_grad]))

201
2
4
9
57


In [113]:
optimizer = torch.optim.Adam(params=[p for p in model.parameters() if p.requires_grad], lr=1e-4)

In [114]:
gradient_accumulation_steps = max(1, int(128 / batch_size))
print(gradient_accumulation_steps)
window = 100
cleanup_step = 50
report_step = 500

8


In [115]:
for batch in loaders['train']:
    break
print(batch.keys())

dict_keys(['attention_mask', 'input_ids', 'token_type_ids', 'labels'])


In [116]:
ewm_loss = 0
model.train()
cleanup()

for epoch in trange(15):
    tq = tqdm(loaders['train'])
    
    for i, batch in enumerate(tq):
        
        if i % report_step == 0:
            model.eval()
            eval_loss = evaluate_model(model, loaders['val'])
            model.train()
            print(f'epoch {epoch}, step {i}: train loss: {ewm_loss:4.4f}  val auc: {eval_loss}')
            cleanup()
        
        try:
            batch = batch.to(model.device)
            output = model(**batch)
            loss = output.loss
            loss.backward()
        except RuntimeError as e:
            print('error on step', i, e)
            loss = None
            cleanup()
            continue

        if i and i % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            
        if i % cleanup_step == 0:
            cleanup()

        w = 1 / min(i+1, window)
        ewm_loss = ewm_loss * (1-w) + loss.item() * w
        tq.set_description(f'loss: {ewm_loss:4.4f}')


  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0, step 0: train loss: 0.0000  val auc: {'recall_NONE': 0.75, 'recall_SAFE': 0.6666666666666666, 'recall_negative': 0.0, 'recall_positive': 0.0, 'macro_recall': 0.35416666666666663, 'macro_f': 0.07724137931034483, 'accuracy': 0.08163265306122448}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 1, step 0: train loss: 1.4109  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.0, 'recall_positive': 1.0, 'macro_recall': 0.25, 'macro_f': 0.20833333333333331, 'accuracy': 0.7142857142857143}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 2, step 0: train loss: 0.8276  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.0, 'recall_positive': 1.0, 'macro_recall': 0.25, 'macro_f': 0.20833333333333331, 'accuracy': 0.7142857142857143}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 3, step 0: train loss: 0.8083  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.0, 'recall_positive': 1.0, 'macro_recall': 0.25, 'macro_f': 0.20833333333333331, 'accuracy': 0.7142857142857143}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 4, step 0: train loss: 0.7804  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.0, 'recall_positive': 1.0, 'macro_recall': 0.25, 'macro_f': 0.20833333333333331, 'accuracy': 0.7142857142857143}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 5, step 0: train loss: 0.7796  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.0, 'recall_positive': 1.0, 'macro_recall': 0.25, 'macro_f': 0.20833333333333331, 'accuracy': 0.7142857142857143}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 6, step 0: train loss: 0.7439  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.0, 'recall_positive': 1.0, 'macro_recall': 0.25, 'macro_f': 0.20833333333333331, 'accuracy': 0.7142857142857143}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 7, step 0: train loss: 0.6991  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.0, 'recall_positive': 1.0, 'macro_recall': 0.25, 'macro_f': 0.20833333333333331, 'accuracy': 0.7142857142857143}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 8, step 0: train loss: 0.6601  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.23529411764705882, 'recall_positive': 0.9714285714285714, 'macro_recall': 0.30168067226890755, 'macro_f': 0.296833064949007, 'accuracy': 0.7346938775510204}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 9, step 0: train loss: 0.5857  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.8714285714285714, 'macro_recall': 0.36491596638655466, 'macro_f': 0.34121621621621623, 'accuracy': 0.7244897959183674}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 10, step 0: train loss: 0.5583  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.17647058823529413, 'recall_positive': 1.0, 'macro_recall': 0.29411764705882354, 'macro_f': 0.28841463414634144, 'accuracy': 0.7448979591836735}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 11, step 0: train loss: 0.5138  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9428571428571428, 'macro_recall': 0.38277310924369745, 'macro_f': 0.37193627450980393, 'accuracy': 0.7755102040816326}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 12, step 0: train loss: 0.4688  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.8823529411764706, 'recall_positive': 0.8571428571428571, 'macro_recall': 0.43487394957983194, 'macro_f': 0.37683823529411764, 'accuracy': 0.7653061224489796}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 13, step 0: train loss: 0.4526  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.0, 'recall_negative': 0.5294117647058824, 'recall_positive': 0.9714285714285714, 'macro_recall': 0.37521008403361344, 'macro_f': 0.3800691244239631, 'accuracy': 0.7857142857142857}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 14, step 0: train loss: 0.4003  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.0, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9285714285714286, 'macro_recall': 0.4417016806722689, 'macro_f': 0.44288793103448276, 'accuracy': 0.7857142857142857}


Полная модель за 15 эпох c шагом 1e-5 не успевает запомнить обучающую выборку,  

train data:
```
{'recall_NONE': 0.0, 'recall_SAFE': 0.875, 'recall_negative': 0.9622641509433962, 'recall_positive': 0.9863481228668942, 'macro_recall': 0.7059030684525727, 'macro_f': 0.6412634519138137, 'accuracy': 0.887468030690537}
```

validation data:
```
{'recall_NONE': 0.0, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5294117647058824, 'recall_positive': 0.8857142857142857, 'macro_recall': 0.43711484593837535, 'macro_f': 0.4322141560798548, 'accuracy': 0.7346938775510204}
```

Если увеличить скорость обучения до 1e-4, мы только оверфитимся:

```
train
{'recall_NONE': 0.9459459459459459, 'recall_SAFE': 0.75, 'recall_negative': 0.9433962264150944, 'recall_positive': 0.9897610921501706, 'macro_recall': 0.9072758161278027, 'macro_f': 0.9287041804696744, 'accuracy': 0.9744245524296675}
val
'recall_NONE': 0.125, 'recall_SAFE': 0.0, 'recall_negative': 0.47058823529411764, 'recall_positive': 0.8428571428571429, 'macro_recall': 0.3596113445378151, 'macro_f': 0.35904095904095906, 'accuracy': 0.6938775510204082}
```

In [117]:
model.eval()
train_loss = evaluate_model(model, loaders['train'])
print(train_loss)

  0%|          | 0/25 [00:00<?, ?it/s]

{'recall_NONE': 0.7027027027027027, 'recall_SAFE': 0.0, 'recall_negative': 0.9056603773584906, 'recall_positive': 0.962457337883959, 'macro_recall': 0.6427051044862881, 'macro_f': 0.6113919741028074, 'accuracy': 0.9104859335038363}


In [118]:
model.eval()
eval_loss = evaluate_model(model, loaders['val'])
print(eval_loss)

  0%|          | 0/7 [00:00<?, ?it/s]

{'recall_NONE': 0.25, 'recall_SAFE': 0.0, 'recall_negative': 0.6470588235294118, 'recall_positive': 0.9428571428571428, 'macro_recall': 0.4599789915966387, 'macro_f': 0.4588675213675214, 'accuracy': 0.8061224489795918}


Поучим ещё 15 эпох

In [119]:
ewm_loss = 0
model.train()
cleanup()

for epoch in trange(15):
    tq = tqdm(loaders['train'])
    
    for i, batch in enumerate(tq):
        
        if i % report_step == 0:
            model.eval()
            eval_loss = evaluate_model(model, loaders['val'])
            model.train()
            print(f'epoch {epoch}, step {i}: train loss: {ewm_loss:4.4f}  val auc: {eval_loss}')
            cleanup()
        
        try:
            batch = batch.to(model.device)
            output = model(**batch)
            loss = output.loss
            loss.backward()
        except RuntimeError as e:
            print('error on step', i, e)
            loss = None
            cleanup()
            continue

        if i and i % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            
        if i % cleanup_step == 0:
            cleanup()

        w = 1 / min(i+1, window)
        ewm_loss = ewm_loss * (1-w) + loss.item() * w
        tq.set_description(f'loss: {ewm_loss:4.4f}')


  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0, step 0: train loss: 0.0000  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.0, 'recall_negative': 0.6470588235294118, 'recall_positive': 0.9428571428571428, 'macro_recall': 0.4599789915966387, 'macro_f': 0.4588675213675214, 'accuracy': 0.8061224489795918}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 1, step 0: train loss: 0.2604  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5294117647058824, 'recall_positive': 0.9857142857142858, 'macro_recall': 0.4621148459383754, 'macro_f': 0.5097402597402597, 'accuracy': 0.8061224489795918}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 2, step 0: train loss: 0.2486  val auc: {'recall_NONE': 0.375, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.8857142857142857, 'macro_recall': 0.5455707282913165, 'macro_f': 0.5800555244494214, 'accuracy': 0.7755102040816326}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 3, step 0: train loss: 0.2201  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9571428571428572, 'macro_recall': 0.5321778711484594, 'macro_f': 0.5709750072653298, 'accuracy': 0.8163265306122449}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 4, step 0: train loss: 0.1543  val auc: {'recall_NONE': 0.0, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.47058823529411764, 'recall_positive': 0.9714285714285714, 'macro_recall': 0.4438375350140056, 'macro_f': 0.4616152450090744, 'accuracy': 0.7857142857142857}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 5, step 0: train loss: 0.1449  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9571428571428572, 'macro_recall': 0.5321778711484594, 'macro_f': 0.5709750072653298, 'accuracy': 0.8163265306122449}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 6, step 0: train loss: 0.1148  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9285714285714286, 'macro_recall': 0.5250350140056022, 'macro_f': 0.555321633735244, 'accuracy': 0.7959183673469388}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 7, step 0: train loss: 0.1086  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9285714285714286, 'macro_recall': 0.5250350140056022, 'macro_f': 0.5557758166491044, 'accuracy': 0.7959183673469388}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 8, step 0: train loss: 0.0767  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9428571428571428, 'macro_recall': 0.5286064425770307, 'macro_f': 0.5640731292517007, 'accuracy': 0.8061224489795918}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 9, step 0: train loss: 0.0643  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9571428571428572, 'macro_recall': 0.5321778711484594, 'macro_f': 0.5709750072653298, 'accuracy': 0.8163265306122449}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 10, step 0: train loss: 0.0552  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9571428571428572, 'macro_recall': 0.5321778711484594, 'macro_f': 0.5709750072653298, 'accuracy': 0.8163265306122449}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 11, step 0: train loss: 0.0562  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9571428571428572, 'macro_recall': 0.5321778711484594, 'macro_f': 0.5709750072653298, 'accuracy': 0.8163265306122449}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 12, step 0: train loss: 0.0457  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9428571428571428, 'macro_recall': 0.5286064425770307, 'macro_f': 0.5640731292517007, 'accuracy': 0.8061224489795918}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 13, step 0: train loss: 0.0367  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9428571428571428, 'macro_recall': 0.5286064425770307, 'macro_f': 0.5608758821087588, 'accuracy': 0.8061224489795918}


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 14, step 0: train loss: 0.0397  val auc: {'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9571428571428572, 'macro_recall': 0.5321778711484594, 'macro_f': 0.5709750072653298, 'accuracy': 0.8163265306122449}


После ещё 15 эпох можешь уже неплохо выучивает обучающую выборку, но всё ещё очень плоха на валидационной

```
{'recall_NONE': 0.9459459459459459, 'recall_SAFE': 1.0, 'recall_negative': 0.9056603773584906, 'recall_positive': 1.0, 'macro_recall': 0.9629015808261092, 'macro_f': 0.9478770967414829, 'accuracy': 0.9820971867007673}

{'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9571428571428572, 'macro_recall': 0.5321778711484594, 'macro_f': 0.5709750072653298, 'accuracy': 0.8163265306122449}
```

In [120]:
model.eval()
train_loss = evaluate_model(model, loaders['train'])
print(train_loss)
eval_loss = evaluate_model(model, loaders['val'])
print(eval_loss)

  0%|          | 0/25 [00:00<?, ?it/s]

{'recall_NONE': 0.9459459459459459, 'recall_SAFE': 1.0, 'recall_negative': 0.9056603773584906, 'recall_positive': 1.0, 'macro_recall': 0.9629015808261092, 'macro_f': 0.9478770967414829, 'accuracy': 0.9820971867007673}


  0%|          | 0/7 [00:00<?, ?it/s]

{'recall_NONE': 0.25, 'recall_SAFE': 0.3333333333333333, 'recall_negative': 0.5882352941176471, 'recall_positive': 0.9571428571428572, 'macro_recall': 0.5321778711484594, 'macro_f': 0.5709750072653298, 'accuracy': 0.8163265306122449}


In [129]:
facts, preds = batched_predict(model, loaders['val'])

  0%|          | 0/7 [00:00<?, ?it/s]

In [131]:
pd.DataFrame(preds).describe()

Unnamed: 0,0,1,2,3
count,98.0,98.0,98.0,98.0
mean,0.043802,0.020825,0.146522,0.788851
std,0.190159,0.13644,0.344236,0.39577
min,0.000117,7.8e-05,6.5e-05,0.000624
25%,0.00038,0.000155,0.000137,0.927233
50%,0.000852,0.000232,0.0003,0.998497
75%,0.002465,0.00073,0.013394,0.999304
max,0.988854,0.96599,0.994371,0.999607


In [132]:
for i in range(4):
    print(roc_auc_score(facts==i, preds[:, i]))    

0.7694444444444444
0.7894736842105263
0.860566448801743
0.8525510204081633


Previous results are (on another train/test split) the following:

so the current model is not so bad (although it can be much, much better).

In [136]:
print('positive recall', 80 / (80+3+7))
print('negative recall', 19 / (12+19+7))
print('none     recall', 19 / (7 +4+19))

print('accuracy', (80 + 19 + 19) / ((80+3+7) + (12+19+7) + (7 +4+19)))

positive recall 0.8888888888888888
negative recall 0.5
none     recall 0.6333333333333333
accuracy 0.7468354430379747
