In [18]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score

from SWEM import Model

import t_utils as utils
from t_utils import jupyter_args, DataBuilder

In [19]:
# Fix the torch RNG up front so model initialisation is repeatable.
torch.manual_seed(42)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Labeled corpus. utils.get_data is project code; it presumably returns the
# train/dev frames plus the tokenizer fitted on the 'title' column -- confirm
# against t_utils.
data_path = '../data/labeled.csv'
d_train, d_dev, tokenizer = utils.get_data(data_path, 'title')

# Vocabulary size passed to the model: one more than the tokenizer length
# (presumably reserving an extra id such as padding -- TODO confirm).
vs = 1 + len(tokenizer)

In [21]:
# Merge the unlabeled pool into the training frame and mark, per row, whether
# a ground-truth target is available.
unlabeled = pd.read_csv('../data/unlabeled.csv')
d_train['labeled'] = True
unlabeled['labeled'] = False
d_train = pd.concat([d_train, unlabeled], ignore_index=True)
# Deterministic shuffle so labeled and unlabeled rows are interleaved.
d_train = d_train.sample(frac=1, random_state=42)
utils.dict_build(d_train.title, True, '../resource/char2idx.pkl')

#d_train = d_train.head(80)

# Unlabeled rows get the sentinel target -1. Series.where keeps the value
# where the condition holds and substitutes -1 elsewhere; this is the
# vectorized equivalent of the former row-wise DataFrame.apply.
is_labeled = d_train['labeled'].astype(bool)
d_train['target'] = d_train['target'].where(is_labeled, -1)

N = d_train.shape[0]        # total number of training rows (labeled + unlabeled)
M = int(is_labeled.sum())   # number of labeled rows

In [4]:
# Hyper-parameters bundled into a single `args` object via the project helper
# `jupyter_args`. Within this notebook, `args.seq_length` is passed to
# DataBuilder, `args.batch_size` to both DataLoaders, `args.learning_rate` to
# Adam, `args.num_classes` sizes the per-sample ensemble buffers, and
# `args.epoch_num` bounds the training loop. The remaining fields
# (embed_dim, hidden_state, vocab_size) are presumably read by
# `Model(args)` -- confirm against SWEM.py.
args = jupyter_args(embed_dim=100,
                    seq_length=35,
                    hidden_state=256,
                    num_classes=14,
                    vocab_size=vs,
                    batch_size=100,
                    learning_rate=3e-3,
                    epoch_num=20)

In [5]:
# Datasets and loaders. Both loaders iterate in a fixed order (shuffle=False);
# d_train was already shuffled once with a fixed seed when the unlabeled rows
# were merged in.
trainbulider = DataBuilder(d_train, 'title', 'target', tokenizer, args.seq_length)
trainloader = DataLoader(trainbulider, args.batch_size, shuffle=False)

devbulider = DataBuilder(d_dev, 'title', 'target', tokenizer, args.seq_length)
devloader = DataLoader(devbulider, args.batch_size, shuffle=False)

# Model and optimizer.
model = Model(args).to(device)
optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)

# Supervised loss (cross-entropy on labeled rows) and consistency loss (MSE
# against the ensembled targets). Both use the default mean reduction; the
# deprecated `reduce=True` argument requested the same thing but triggers a
# deprecation warning, so it is omitted.
labeled_lossfunc = nn.CrossEntropyLoss().to(device)
unlabeled_lossfunc = nn.MSELoss().to(device)

In [15]:
def w_func(epoch, max_epoch, M , N, w_max=30):
    """Ramp-up weight for the unsupervised (consistency) loss term.

    The weight is 0 for the first epoch, then follows a Gaussian ramp-up
    ``w_max * (M / N) * exp(-5 * (1 - T) ** 2)`` with ``T = epoch / max_epoch``,
    and stays at its plateau ``w_max * M / N`` once ``epoch >= max_epoch``.

    Args:
        epoch: Zero-based index of the current epoch.
        max_epoch: Epoch at which the ramp-up reaches its maximum.
        M: Number of labeled samples.
        N: Total number of samples (labeled + unlabeled).
        w_max: Maximum weight before scaling by the labeled fraction ``M / N``.

    Returns:
        The loss weight as a float (``0.0`` for epoch 0).
    """
    if epoch == 0:
        return 0.0

    # Clamp so the weight plateaus after the ramp-up period. The original
    # assigned from the misspelled name `max_spoch`, which raised NameError
    # for every epoch beyond `max_epoch`.
    epoch = min(epoch, max_epoch)

    T = epoch / max_epoch
    # Scale the ceiling by the fraction of labeled data.
    w_max = w_max * M / N

    return w_max * np.exp(-5 * (1 - T) ** 2)

In [25]:
# Semi-supervised training loop in the temporal-ensembling style:
#   * labeled rows contribute a cross-entropy loss,
#   * every row contributes an MSE loss between its current logits and the
#     bias-corrected moving average of its logits from earlier epochs (z_wavy),
#   * the MSE term is weighted by w_func, which ramps up over the first epochs.
alpha = 0.6          # momentum of the per-sample moving average
ramp_up_epochs = 8   # epoch at which the consistency weight reaches its plateau
z = torch.zeros((N, args.num_classes), requires_grad=False).to(device)       # raw moving average
z_wavy = torch.zeros((N, args.num_classes), requires_grad=False).to(device)  # bias-corrected targets

for epoch in range(args.epoch_num):

    y_pred = []
    y_true = []
    epoch_loss = 0
    labeled_batches = 0  # batches that actually contributed to epoch_loss
    # The epoch is clamped so the weight plateaus once the ramp-up is over.
    w = torch.FloatTensor(
        [w_func(min(epoch, ramp_up_epochs), ramp_up_epochs, M, N, w_max=30)]
    ).to(device)

    model.train()
    t = epoch + 1  # one-based epoch: used for bias correction and reporting
    outputs = torch.zeros((N, args.num_classes)).to(device)

    for x, mask, target, label_mask, index in trainloader:

        x = x.to(device)
        mask = mask.to(device)
        index = index.to(device)

        optimizer.zero_grad()
        logits = model(x, mask)
        # Store this epoch's logits under each sample's own index, the same key
        # used to read z_wavy below. Slicing by `i * bs` misplaced the rows of a
        # short final batch and only matched `index` by coincidence of ordering.
        outputs[index] = logits.detach()

        target = target.view(-1).to(device)
        # Assumes label_mask holds one 0/1 flag per sample. Comparing with 0
        # yields the mask dtype the installed torch expects for indexing.
        is_labeled = label_mask.to(device) != 0
        labeled_target = target[is_labeled]
        labeled_logits = logits[is_labeled]

        if labeled_target.size(0) == 0:
            # No labeled rows in this batch: only the consistency term applies.
            labeled_loss = 0
        else:
            labeled_loss = labeled_lossfunc(labeled_logits, labeled_target)

            predict = labeled_logits.max(1)[1]
            y_true += labeled_target.tolist()
            y_pred += predict.tolist()
            epoch_loss += labeled_loss.item()
            labeled_batches += 1

        z_wavy_batch = z_wavy.index_select(0, index)
        unlabeled_loss = unlabeled_lossfunc(logits, z_wavy_batch)

        loss = labeled_loss + w * unlabeled_loss
        loss.backward()
        optimizer.step()

    print(outputs.size())
    # Update the moving average and undo its start-up bias towards zero.
    z = alpha * z + (1 - alpha) * outputs
    z_wavy = z / (1 - alpha ** t)


    model.eval()
    dev_true, dev_pred, dev_loss = utils.eval(model, devloader, labeled_lossfunc)
    dev_acc = accuracy_score(dev_true, dev_pred)

    # Average over the batches that produced a supervised loss; dividing by
    # len(trainloader) under-reported the loss whenever a batch had no labels.
    train_loss = epoch_loss / max(labeled_batches, 1)
    train_acc = accuracy_score(y_true, y_pred)

    print('''------------------------------------
Epoch %d
Train_loss: %.4f  Dev_loss: %.4f
Train_acc: %.4f  Dev_acc: %.4f
------------------------------------''' % (t, train_loss, dev_loss, train_acc, dev_acc))

#torch.save(model.state_dict(), '../model/swem_bs15_lr5e-3_total.pkl')

torch.Size([96000, 14])
------------------------------------
Epoch 1
Train_loss: 0.0060  Dev_loss: 1.2193
Train_acc: 0.9997  Dev_acc: 0.7473
------------------------------------


KeyboardInterrupt: 