In [24]:
from src.models.bert import BertClassifier
from src.dataloader.dataloading import TrainDataset, TestDataset
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device: ", DEVICE)

Using device:  cuda


Load the fine-tuned model and the datasets

In [33]:
# Input locations for this notebook.
WEIGHTS_PATH = "model_weights/bert_clf_augmented_data.pth"
TEST_PATH = './data/test_shuffle.txt'
TRAIN_PATH = './data/augmented.json'

model = BertClassifier()
# map_location remaps the saved tensors onto DEVICE, so weights saved from a GPU
# run can still be loaded on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
model.to(DEVICE);

test_dataset = TestDataset(TEST_PATH)
# shuffle stays False (the DataLoader default): the prediction for row i is later
# matched to test_dataset.sentences[i], so batch order must follow dataset order.
test_dataloader = DataLoader(test_dataset, batch_size=128)

# In this notebook the training dataset is used only for its label list: output
# column i of the model is decoded as labels[i].
# NOTE(review): assumes this label order matches the one used when the weights
# were trained — TODO confirm.
train_dataset = TrainDataset(TRAIN_PATH)
labels = train_dataset.labels

Generate pseudo-labels for the test set

In [34]:
# Score every test sentence with the loaded classifier (no gradients needed).
# NOTE(review): the 0.99 threshold used later treats these outputs as per-class
# probabilities; if the model returns raw logits, a softmax is needed — TODO confirm.
model.eval()
batch_outputs = []
with torch.no_grad():
    for batch in test_dataloader:
        # Each batch is (input_ids, attention_mask, token_type_ids); move all to DEVICE.
        input_ids, attention_mask, token_type_ids = (t.to(DEVICE) for t in batch)
        batch_outputs.append(model(input_ids, attention_mask, token_type_ids).cpu())

probs = torch.cat(batch_outputs, dim=0)
# One reduction gives both the top score (confidence) and its class index.
top = probs.max(dim=1)
confidence = top.values.numpy()
preds = [labels[idx] for idx in top.indices.tolist()]
texts = test_dataset.sentences

pred_df = pd.DataFrame({'label': preds, 'confidence': confidence})

In [35]:
# Keep only predictions whose top-class score exceeds the threshold; these rows
# become pseudo-labels in the next cell.
thresh = 0.99
confident_df = pred_df[pred_df['confidence'] > thresh]

# Reindex the confident counts over every predicted label, so a label with no
# confident predictions shows up as 0. Otherwise it would be NaN in the ratio
# and silently missing from the count listing.
predicted_counts = pred_df['label'].value_counts()
confident_counts = (
    confident_df['label']
    .value_counts()
    .reindex(predicted_counts.index, fill_value=0)
)

# Percentage (as a fraction) of each label's predictions that passed the threshold.
print((confident_counts / predicted_counts).sort_index())
# Number of pseudo-labels gained per label, largest first.
print(confident_counts.sort_values(ascending=False))
print(confident_df.index)

label
Education        0.821429
Entertainment    0.489362
Environment      0.578125
Fashion          0.822917
Finance          0.697674
Food             0.285714
Health           0.773333
Politics         0.839080
Science          0.567164
Sports           0.423913
Technology       0.514286
Travel           0.686275
Name: count, dtype: float64
label
Health           116
Education         92
Fashion           79
Environment       74
Politics          73
Travel            70
Finance           60
Entertainment     46
Sports            39
Science           38
Technology        36
Food              16
Name: count, dtype: int64
Index([   0,    1,    2,    4,    5,    6,    7,    8,   10,   11,
       ...
       1120, 1121, 1122, 1125, 1127, 1128, 1133, 1134, 1136, 1138],
      dtype='int64', length=739)


Augment the training data with confident pseudo-labels

In [36]:
import json

# Start from the original training data on every run, so re-running this cell
# does not append the same pseudo-labels twice.
# NOTE(review): the label list above came from augmented.json, but pseudo-labels
# are appended to train.json — confirm this is the intended base file.
with open('data/train.json', encoding='utf-8') as f:
    train_data = json.load(f)

# Fail early with a clear message instead of a bare KeyError part-way through the loop.
unknown_labels = set(confident_df['label']) - set(train_data)
assert not unknown_labels, f"pseudo-labels not present in train.json: {sorted(unknown_labels)}"

# confident_df keeps pred_df's default positional index, so each index value is
# also the position of the matching sentence in `texts`.
for idx, label in confident_df['label'].items():
    train_data[label].append(texts[idx])

with open('data/augmented_semi.json', 'w', encoding='utf-8') as f:
    json.dump(train_data, f)