In [2]:
import torch
import torchshard as ts
from torch import nn, optim
import pytorch_lightning as pl
from torch.nn import functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from sklearn.utils.class_weight import compute_class_weight

In [3]:
# Load the saved feature tensor and flatten everything after the first
# (sample) dimension so each sample becomes a single row vector.
features_raw = torch.load('X.pt')
X = features_raw.view(len(features_raw), -1)

# Labels are cast to int64 (long).
y = torch.load('y.pt').long()

print(X.shape)

torch.Size([15483, 24576])


In [3]:
# --- Run configuration -------------------------------------------------------
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Data split / loader settings.
TEST_SIZE = 0.3      # fraction of samples held out for validation
BATCH_SIZE = 32

# Model / optimiser hyper-parameters.
DROPOUT = 0.5
LEARNING_RATE = 3e-4

# "balanced" class weights for labels 0..6, computed from the label tensor.
# NOTE(review): the final layer of NN has 6 outputs while 7 classes are listed
# here, and NN currently does not use this value -- confirm the label range.
CLASS_WEIGHT = compute_class_weight(
    'balanced',
    classes=list(range(7)),
    y=y.numpy(),
)

In [4]:
class NN_Indiv(pl.LightningModule):
    """Small feed-forward encoder applied to one 3072-wide slice of the input.

    Layer stack (in order): ParallelLinear(3072 -> 192), GELU, BatchNorm1d(192),
    Dropout, ParallelLinear(192 -> 192).
    """

    def __init__(self, dropout, lr):
        """Build the layer stack.

        Args:
            dropout: drop probability handed to ``nn.Dropout``.
            lr: accepted for signature compatibility; not used by this module.
        """
        super().__init__()
        width = 192
        stack = [
            ts.nn.ParallelLinear(3072, width),
            nn.GELU(),
            nn.BatchNorm1d(width),
            nn.Dropout(dropout),
            ts.nn.ParallelLinear(width, width),
        ]
        # Unpacking a list keeps the same positional sub-module names
        # (fc.0 ... fc.4) as passing the layers directly.
        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.fc(x)

In [5]:
class NN(pl.LightningModule):
    """Scores inputs by combining eight per-slice encoders and two language heads.

    ``forward`` splits the input into 8 equal chunks along dim 1, feeds each
    chunk to its own ``NN_Indiv`` encoder, concatenates the outputs, splits
    them in two halves (``en`` / ``zh``), passes each half through its own head
    and finally maps the concatenated heads to 6 logits.
    """

    def __init__(self, device, train_dataset, val_dataset, batch_size, lr, dropout, class_weight):
        """Build the network and remember the training configuration.

        Args:
            device: device on which the sub-encoders and ``pos_weights`` are
                initially created.
            train_dataset: dataset served by ``train_dataloader``.
            val_dataset: dataset served by ``val_dataloader``.
            batch_size: batch size for both dataloaders.
            lr: learning rate for Adam.
            dropout: drop probability for every ``nn.Dropout`` layer.
            class_weight: currently unused (kept for the commented-out
                cross-entropy criterion below).
        """
        super().__init__()
        # nn.ModuleList (rather than a plain Python list) registers the eight
        # encoders as sub-modules, so their parameters are returned by
        # self.parameters() (and therefore optimised), are included in
        # state_dict(), and follow .to()/.eval()/.train() calls on this module.
        self.comps = nn.ModuleList(
            [NN_Indiv(dropout, lr).to(device) for _ in range(8)]
        )
        self.en_fc = nn.Sequential(
            nn.GELU(),
            nn.BatchNorm1d(768),
            nn.Dropout(dropout),
            ts.nn.ParallelLinear(768, 384)
        )
        self.zh_fc = nn.Sequential(
            nn.GELU(),
            nn.BatchNorm1d(768),
            nn.Dropout(dropout),
            ts.nn.ParallelLinear(768, 384)
        )
        self.fc = nn.Sequential(
            nn.GELU(),
            nn.BatchNorm1d(768),
            nn.Dropout(dropout),
            ts.nn.ParallelLinear(768, 6)
        )
        # self.criterion = nn.CrossEntropyLoss(weight=class_weight)

        # Per-class position weights (class k -> 1 - k/5). Registered as a
        # non-persistent buffer: it moves with the module across devices but is
        # not written to state_dict(), and it never requires grad.
        self.register_buffer(
            "pos_weights",
            torch.tensor([1, .8, .6, .4, .2, 0], dtype=torch.float32, device=device),
            persistent=False,
        )
        # Keep the datasets passed by the caller instead of relying on
        # notebook-level globals of the same name.
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.lr = lr

    def forward(self, x):
        """Return the 6 logits for each row of ``x``."""
        chunks = torch.tensor_split(x, 8, dim=1)
        encoded = torch.cat([m(t) for m, t in zip(self.comps, chunks)], dim=1)
        en, zh = torch.tensor_split(encoded, 2, dim=1)
        xx = torch.cat((self.en_fc(en), self.zh_fc(zh)), dim=1)
        return self.fc(xx)

    def loss_func(self, logits, y):
        """Sum of squared differences between predicted and target position.

        The predicted position is the softmax-weighted sum of ``pos_weights``;
        the target position is ``1 - y / 5``. The result is summed (not
        averaged) over the batch.
        """
        weighted_pos = (F.softmax(logits, dim=1) * self.pos_weights).sum(axis=1)
        y_pos = 1 - y / 5
        pos_dis = ((weighted_pos - y_pos) ** 2).sum()
        return pos_dis

    def training_step(self, batch, idx):
        """Compute and log the training loss for one batch."""
        X, y = batch
        logits = self(X)
        loss = self.loss_func(logits, y)
        self.log('train_pos_dis', loss)
        return loss

    def validation_step(self, batch, idx):
        """Compute and log the validation loss for one batch."""
        X, y = batch
        logits = self(X)
        loss = self.loss_func(logits, y)
        self.log('val_pos_dis', loss)
        return loss

    def sort_notis(self, text):
        """Score each row of ``text`` and return ``(row, score)`` pairs, best first.

        Relies on the notebook-level names ``en_model``, ``en_tokenizer``,
        ``zh_model``, ``zh_tokenizer`` and ``tqdm`` being defined before the
        call. Prints the softmax probabilities as a side effect.

        Args:
            text: sequence of rows; each row is passed as-is to both tokenizers.

        Returns:
            List of ``(row, score)`` tuples sorted by descending score, or an
            empty list when ``text`` is empty.
        """
        if len(text) == 0:
            # torch.stack raises on an empty list; nothing to score.
            return []

        def first_token_features(model, tokenizer, row):
            # Index 2 of the model output is taken with output_hidden_states=True;
            # the last four entries are concatenated and the first token of
            # each sequence is kept.
            encoded = tokenizer(row, return_tensors='pt', padding=True, truncation=True)
            hidden_states = model(**encoded, output_hidden_states=True)[2]
            return torch.cat(hidden_states[-4:])[:, 0].detach()

        # Inference must not use active Dropout or batch statistics in
        # BatchNorm; force eval mode and restore the previous mode afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                X = torch.stack([
                    torch.cat((
                        first_token_features(en_model, en_tokenizer, row),
                        first_token_features(zh_model, zh_tokenizer, row),
                    ))
                    for row in tqdm(text)
                ]).to(self.device)
                X = X.view(X.shape[0], -1)
                probs = F.softmax(self(X), dim=1)
                print(probs)
                weighted_pos = (probs * self.pos_weights).sum(axis=1).cpu().numpy()
        finally:
            self.train(was_training)

        text_score = list(zip(text, weighted_pos))
        text_score.sort(key=lambda pair: -pair[1])
        return text_score

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        # Only a positive max_steps is an explicit limit; None or -1 mean
        # "not set" depending on the Lightning version.
        max_steps = self.trainer.max_steps
        if max_steps is not None and max_steps > 0:
            return max_steps

        limit_batches = self.trainer.limit_train_batches
        batches = len(self.train_dataloader())
        batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_accum = self.trainer.accumulate_grad_batches * num_devices
        # Never return 0: it is used as T_max of the cosine schedule.
        return max(1, (batches // effective_accum) * self.trainer.max_epochs)

    def configure_optimizers(self):
        """Adam with a cosine schedule spanning the whole run, stepped per batch."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, self.num_training_steps)
        # T_max is expressed in optimizer steps, so the scheduler must be
        # advanced every step rather than Lightning's default of every epoch.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def train_dataloader(self):
        """Shuffled loader over the training dataset given to the constructor."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=8, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation dataset given to the constructor."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=8)

In [6]:
# Fixed random_state so the train/validation split is identical on every run;
# otherwise validation metrics and reloaded checkpoints are evaluated against a
# different hold-out set each time the notebook is re-executed.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=42
)
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

model = NN(DEVICE, train_dataset, val_dataset, BATCH_SIZE, LEARNING_RATE, DROPOUT, CLASS_WEIGHT)

In [None]:
# Record the learning rate at every step.
lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = Trainer(
    #auto_lr_find=True,
    #auto_scale_batch_size="binsearch",
    callbacks=[lr_monitor],
    # Follow DEVICE instead of hard-coding one GPU, so the cell also runs on a
    # CPU-only machine (DEVICE already falls back to "cpu" there).
    gpus=1 if DEVICE == "cuda" else 0,
    #logger=False,
    max_epochs=1500,
    profiler="simple",
    stochastic_weight_avg=True,
    track_grad_norm=2,
    weights_save_path="model.pt",
)

# NOTE(review): with auto_lr_find / auto_scale_batch_size commented out above,
# tune() presumably has nothing to tune -- confirm before removing.
trainer.tune(model)
trainer.fit(model)

In [10]:
# Persist the top-level model weights, then one file per sub-encoder
# (model0.pt ... model7.pt for the eight entries of model.comps).
torch.save(model.state_dict(), 'model8.pt')
for idx in range(len(model.comps)):
    torch.save(model.comps[idx].state_dict(), f'model{idx}.pt')

In [11]:
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint
# opened on a CPU-only machine) onto DEVICE instead of failing at load time.
model.load_state_dict(torch.load('model8.pt', map_location=DEVICE))
model = model.to(DEVICE)
model.eval()

# Restore each sub-encoder from its own file and switch it to eval mode.
for i, comp in enumerate(model.comps):
    comp.load_state_dict(torch.load(f'model{i}.pt', map_location=DEVICE))
    model.comps[i] = comp.to(DEVICE)
    model.comps[i].eval()

In [33]:
import pandas as pd
from tqdm import tqdm
from pprint import pprint
from random import sample
from transformers import AutoTokenizer, AutoModel

# Checkpoint names for the English and Chinese encoders.
EN_CHECKPOINT = "bert-base-cased"
ZH_CHECKPOINT = "bert-base-chinese"

# These four names are read as globals by NN.sort_notis.
en_tokenizer = AutoTokenizer.from_pretrained(EN_CHECKPOINT, output_hidden_states=True)
en_model = AutoModel.from_pretrained(EN_CHECKPOINT, output_hidden_states=True)
zh_tokenizer = AutoTokenizer.from_pretrained(ZH_CHECKPOINT, output_hidden_states=True)
zh_model = AutoModel.from_pretrained(ZH_CHECKPOINT, output_hidden_states=True)

# One row per item to rank; each row is a list of four strings.
text = [
    ['Messenger', 'IM', '公司', '急'],
]

ranked = model.sort_notis(text)
pprint(ranked)

100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 12.42it/s]

tensor([[3.6560e-04, 7.8779e-01, 1.7674e-03, 1.1068e-01, 2.5234e-05, 9.9370e-02]],
       device='cuda:0')
[(['Messenger', 'IM', '公司', '急'], 0.67593586)]



