이 파일의 소스코드는 아래의 레포지토리를 참고했다.

#### Reference: https://github.com/zhongyuchen/few-shot-text-classification

In [1]:
import os
import torch
from model import FewShotInduction
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from glob import glob
from tqdm import tqdm
from torch import optim
from torch.nn.utils.rnn import pad_sequence
from criterion import Criterion

In [2]:
data_path = 'Amazon_few_shot'

In [3]:
# 반드시 do_lower_case=True로 해야 한다.
# bert-base-uncased는 영어 데이터를 소문자로 변환해서 학습한 모델이기 때문이다.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

### make dataset and dataloader

In [4]:
class AmazonDataset():
    def __init__(self, data_path, tokenizer, dtype):
        self.data_path = data_path
        self.tokenizer = tokenizer
        with open(f'{dtype}.list', 'r') as f:
            self.categories = [oneline.rstrip() for oneline in f]
        self.support_dataset = {}
        self.dataset = {}
        for category in tqdm(self.categories, desc='reading categories'):
            self.dataset[category] = {
                'neg': self.get_data(category, 'neg', dtype),
                'pos': self.get_data(category, 'pos', dtype)
            }
        
        if dtype == 'test' or dtype == 'dev':
            for category in tqdm(self.categories, desc='reading categories for support'):
                self.support_dataset[category] = {
                    'neg': self.get_data(category, 'neg', 'train'),
                    'pos': self.get_data(category, 'pos', 'train'),
                }
        
    def read_files(self, category, label, dtype):
        data = {
            'text': [],
            'label': []
        }
        for t in ['t2', 't4', 't5']:
            filename = f'{category}.{t}.{dtype}'
            with open(os.path.join(self.data_path, filename), 'r', encoding='utf-8') as f:
                for oneline in f:
                    oneline = oneline.rstrip()
                    text = oneline[:-2]
                    if int(oneline[-2:]) == 1 and label == 'pos':
                        tensor = self.tokenizer(text, return_tensors='pt')
                        data['text'].append(tensor['input_ids'][0])
                        data['label'].append(1)
                    elif int(oneline[-2:]) == -1 and label == 'neg':
                        tensor = self.tokenizer(text, return_tensors='pt')
                        data['text'].append(tensor['input_ids'][0])
                        data['label'].append(0)
        data['label'] = torch.tensor(data['label'])
        return data
    
    def get_data(self, category, label, dtype):
        data = self.read_files(category, label, dtype)
        return data

In [5]:
train_dataset = AmazonDataset(data_path, tokenizer, 'train')
dev_dataset = AmazonDataset(data_path, tokenizer, 'dev')
test_dataset = AmazonDataset(data_path, tokenizer, 'test')

reading categories:   0%|          | 0/14 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (625 > 512). Running this sequence through the model will result in indexing errors
reading categories: 100%|██████████| 14/14 [01:14<00:00,  5.35s/it]
reading categories: 100%|██████████| 5/5 [00:00<00:00, 11.93it/s]
reading categories for support: 100%|██████████| 5/5 [00:03<00:00,  1.43it/s]
reading categories: 100%|██████████| 4/4 [00:08<00:00,  2.11s/it]
reading categories for support: 100%|██████████| 4/4 [00:00<00:00, 36.36it/s]


In [6]:
def pad_text(a_text, b_text):
    a_text_len = a_text.shape[1]
    b_text_len = b_text.shape[1]

    if a_text_len > b_text_len:
        b_text = torch.cat([b_text, torch.zeros(b_text.shape[0], a_text_len-b_text_len).long()], dim=1)
    else:
        a_text = torch.cat([a_text, torch.zeros(a_text.shape[0], b_text_len-a_text_len).long()], dim=1)
        
    return a_text, b_text

In [7]:
class AmazonDataLoader():
    def __init__(self, dataset, batch_size, n_support):
        assert n_support % 2 == 0, 'n_support should be multiple of 2'
        self.dataset = dataset
        self.batch_size = batch_size
        self.n_support = n_support
        self.neg_idx = {k:0 for k in dataset.dataset}
        self.pos_idx = {k:0 for k in dataset.dataset}
        self.neg_len = {k:len(dataset.dataset[k]['neg']['text']) for k in dataset.dataset}
        self.pos_len = {k:len(dataset.dataset[k]['pos']['text']) for k in dataset.dataset}
        self.neg = {k:dataset.dataset[k]['neg'] for k in dataset.dataset}
        self.pos = {k:dataset.dataset[k]['pos'] for k in dataset.dataset}
        self.idx = 0
        self.categories = [k for k in dataset.dataset]
        
        # prepare for test dataset, support dataset should come from "*.train"
        self.neg_support_idx = {}
        self.pos_support_idx = {}
        self.neg_support_len = {}
        self.pos_support_len = {}
        if self.dataset.support_dataset:
            self.neg_support_idx = {k:0 for k in self.dataset.support_dataset}
            self.pos_support_idx = {k:0 for k in self.dataset.support_dataset}
            self.neg_support_len = {k:len(self.dataset.support_dataset[k]['neg']['text']) for k in self.dataset.support_dataset}
            self.pos_support_len = {k:len(self.dataset.support_dataset[k]['pos']['text']) for k in self.dataset.support_dataset}
        
    def get_batch(self):
        category = self.categories[self.idx % len(self.categories)]
        neg = self.neg[category]
        pos = self.pos[category]
        neg_start_idx = self.neg_idx[category] % self.neg_len[category]
        pos_start_idx = self.pos_idx[category] % self.pos_len[category]
        
        # prepare negative/positive dataset
        neg_text = neg['text'][neg_start_idx:neg_start_idx+(self.batch_size//2)]
        pos_text = pos['text'][pos_start_idx:pos_start_idx+(self.batch_size//2)]
        neg_label = neg['label'][neg_start_idx:neg_start_idx+(self.batch_size//2)]
        pos_label = pos['label'][pos_start_idx:pos_start_idx+(self.batch_size//2)]
        self.neg_idx[category] += (self.batch_size//2)
        self.pos_idx[category] += (self.batch_size//2)
        
        if len(neg_text) + len(pos_text) != self.batch_size:
            return self.get_batch()
            
        # padding text dataset
        neg_text = pad_sequence([n for n in neg_text], batch_first=True)
        pos_text = pad_sequence([p for p in pos_text], batch_first=True)
        neg_text, pos_text = pad_text(neg_text, pos_text)
            
        # prepare support/query text
        neg_support_text = neg_text[:self.n_support//2]
        pos_support_text = pos_text[:self.n_support//2]
        neg_query_text = neg_text[self.n_support//2:]
        pos_query_text = pos_text[self.n_support//2:]
        
        # prepare support/query label
        neg_support_label = neg_label[:self.n_support//2]
        pos_support_label = pos_label[:self.n_support//2]
        neg_query_label = neg_label[self.n_support//2:]
        pos_query_label = pos_label[self.n_support//2:]
        
        # merge support/query text
        support_text = torch.cat([neg_support_text, pos_support_text], dim=0)
        query_text = torch.cat([neg_query_text, pos_query_text], dim=0)
        
        # merge support/query label
        support_label = torch.cat([neg_support_label, pos_support_label], dim=0)
        query_label = torch.cat([neg_query_label, pos_query_label], dim=0)
        
        # make data and label
        data = torch.cat([support_text, query_text], dim=0)
        label = torch.cat([support_label, query_label], dim=0)
        
        # increase category index
        self.idx += 1
        return data, label
    
    def get_batch_test(self):
        assert self.dataset.support_dataset, 'support_dataset is empty'
        
        category = self.categories[self.idx % len(self.categories)]
        neg = self.neg[category]
        pos = self.pos[category]
        neg_query_start_idx = self.neg_idx[category] % self.neg_len[category]
        pos_query_start_idx = self.pos_idx[category] % self.pos_len[category]
        neg_support_start_idx = self.neg_support_idx[category] % self.neg_support_len[category]
        pos_support_start_idx = self.pos_support_idx[category] % self.pos_support_len[category]
        
        # prepare negative/positive support dataset from support_dataset
        category_suuport_dataset = self.dataset.support_dataset[category]
        neg_support_text = category_suuport_dataset['neg']['text'][neg_support_start_idx:neg_support_start_idx+self.n_support//2]
        pos_support_text = category_suuport_dataset['pos']['text'][pos_support_start_idx:pos_support_start_idx+self.n_support//2]
        neg_support_label = category_suuport_dataset['neg']['label'][neg_support_start_idx:neg_support_start_idx+self.n_support//2]
        pos_support_label = category_suuport_dataset['pos']['label'][pos_support_start_idx:pos_support_start_idx+self.n_support//2]
        self.neg_support_idx[category] += (self.n_support//2)
        self.pos_support_idx[category] += (self.n_support//2)
        
        # prepare negative/positive query dataset
        neg_query_text = neg['text'][neg_query_start_idx:neg_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        pos_query_text = pos['text'][pos_query_start_idx:pos_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        neg_query_label = neg['label'][neg_query_start_idx:neg_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        pos_query_label = pos['label'][pos_query_start_idx:pos_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        self.neg_idx[category] += (self.batch_size//2 - self.n_support//2)
        self.pos_idx[category] += (self.batch_size//2 - self.n_support//2)
        
        # padding support text dataset
        if self.n_support:
            neg_support_text = pad_sequence([n for n in neg_support_text], batch_first=True)
            pos_support_text = pad_sequence([n for n in pos_support_text], batch_first=True)
            neg_support_text, pos_support_text = pad_text(neg_support_text, pos_support_text)
        else:
            neg_support_text = torch.tensor([[]])
            pos_support_text = torch.tensor([[]])
            
        # padding text dataset
        neg_query_text = pad_sequence([n for n in neg_query_text], batch_first=True)
        pos_query_text = pad_sequence([p for p in pos_query_text], batch_first=True)
        neg_query_text, pos_query_text = pad_text(neg_query_text, pos_query_text)

        # concatenating support/query text dataset
        support_text = torch.cat([neg_support_text, pos_support_text], dim=0)
        query_text = torch.cat([neg_query_text, pos_query_text], dim=0)
        support_text, query_text = pad_text(support_text, query_text)

        # make final data and label
        if self.n_support:
            data = torch.cat([support_text, query_text], dim=0)
        else:
            data = query_text
        label = torch.cat([neg_support_label, pos_support_label, neg_query_label, pos_query_label], dim=0)
        return data, label

In [8]:
support = 5

In [9]:
train_dataloader = AmazonDataLoader(train_dataset, batch_size=64, n_support=support*2)
dev_dataloader = AmazonDataLoader(dev_dataset, batch_size=64, n_support=support*2)
test_dataloader = AmazonDataLoader(test_dataset, batch_size=64, n_support=support*2)

In [10]:
for i in range(10):
    d, l = train_dataloader.get_batch()
    print(d.shape, l.float().mean())

torch.Size([64, 149]) tensor(0.5000)
torch.Size([64, 460]) tensor(0.5000)
torch.Size([64, 254]) tensor(0.5000)
torch.Size([64, 262]) tensor(0.5000)
torch.Size([64, 1283]) tensor(0.5000)
torch.Size([64, 1658]) tensor(0.5000)
torch.Size([64, 613]) tensor(0.5000)
torch.Size([64, 359]) tensor(0.5000)
torch.Size([64, 530]) tensor(0.5000)
torch.Size([64, 602]) tensor(0.5000)


In [11]:
for i in range(10):
    d, l = dev_dataloader.get_batch_test()
    print(d.shape, l.float().mean())

torch.Size([64, 327]) tensor(0.5000)
torch.Size([55, 181]) tensor(0.5818)
torch.Size([64, 198]) tensor(0.5000)
torch.Size([46, 295]) tensor(0.6957)
torch.Size([64, 197]) tensor(0.5000)
torch.Size([64, 276]) tensor(0.5000)
torch.Size([55, 186]) tensor(0.5818)
torch.Size([64, 270]) tensor(0.5000)
torch.Size([26, 130]) tensor(0.4615)
torch.Size([64, 327]) tensor(0.5000)


In [12]:
for i in range(10):
    d, l = test_dataloader.get_batch_test()
    print(d.shape, l.float().mean())

torch.Size([64, 743]) tensor(0.5000)
torch.Size([64, 841]) tensor(0.5000)
torch.Size([64, 1386]) tensor(0.5000)
torch.Size([64, 706]) tensor(0.5000)
torch.Size([64, 1026]) tensor(0.5000)
torch.Size([64, 1126]) tensor(0.5000)
torch.Size([64, 1116]) tensor(0.5000)
torch.Size([64, 1333]) tensor(0.5000)
torch.Size([64, 568]) tensor(0.5000)
torch.Size([64, 570]) tensor(0.5000)


### define model

In [13]:
model = FewShotInduction(C=2,
                         S=support,
                         vocab_size=len(tokenizer),
                         embed_size=300,
                         hidden_size=128,
                         d_a=64,
                         iterations=3,
                         outsize=100)
model = model.cuda()

In [14]:
len(tokenizer)

30522

In [15]:
optimizer = optim.Adam(model.parameters(), lr=float(1e-4))

In [16]:
criterion = Criterion(way=2, shot=support)

In [17]:
def train(episode):
    model.train()
    data, target = train_dataloader.get_batch()
    data = data.cuda()
    target = target.cuda()
    optimizer.zero_grad()
    predict = model(data)
    loss, acc = criterion(predict, target)
    loss.backward()
    optimizer.step()
    return loss

In [18]:
def dev(episode):
    model.eval()
    correct = 0.
    count = 0.
    for i in range(100):
        data, target = dev_dataloader.get_batch_test()
        data = data.cuda()
        target = target.cuda()
        predict = model(data)
        _, acc = criterion(predict, target)
        amount = len(target) - support * 2
        correct += acc * amount
        count += amount
    acc = correct / count
    return acc

In [19]:
dev_interval = 100
best_acc = -1.0

In [20]:
tbar = tqdm(range(1, 10000))
for episode in tbar:
    
    loss = train(episode)
    if episode % dev_interval == 0:
        acc = dev(episode)
        if acc > best_acc:
            print('Better acc! Saving model! -> {:.4f}'.format(acc))
            best_acc = acc
    tbar.set_postfix(loss=loss)  

  0%|          | 0/9999 [00:00<?, ?it/s]Could not load symbol cublasGetSmCountTarget from libcublas.so.11. Error: /usr/local/cuda-11.2/lib64/libcublas.so.11: undefined symbol: cublasGetSmCountTarget
  1%|          | 102/9999 [00:10<27:36,  5.98it/s, loss=tensor(0.5045, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.5389


  2%|▏         | 202/9999 [00:21<29:25,  5.55it/s, loss=tensor(0.5000, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.5426


 48%|████▊     | 4801/9999 [08:33<18:32,  4.67it/s, loss=tensor(0.4628, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.5569


 49%|████▉     | 4902/9999 [08:43<16:58,  5.00it/s, loss=tensor(0.5104, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.5618


 54%|█████▍    | 5401/9999 [09:37<14:37,  5.24it/s, loss=tensor(0.4488, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.5996


 58%|█████▊    | 5802/9999 [10:20<11:47,  5.94it/s, loss=tensor(0.1758, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6111


 59%|█████▉    | 5902/9999 [10:30<11:28,  5.95it/s, loss=tensor(0.3463, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6309


 60%|██████    | 6002/9999 [10:41<11:10,  5.96it/s, loss=tensor(0.3233, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6477


 61%|██████    | 6101/9999 [10:52<12:09,  5.34it/s, loss=tensor(0.2904, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6640


 64%|██████▍   | 6400/9999 [11:24<13:59,  4.29it/s, loss=tensor(0.1120, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6786


 66%|██████▌   | 6602/9999 [11:45<11:14,  5.04it/s, loss=tensor(0.3230, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6790


 69%|██████▉   | 6900/9999 [12:18<12:16,  4.21it/s, loss=tensor(0.2051, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6795


 72%|███████▏  | 7202/9999 [12:50<08:14,  5.65it/s, loss=tensor(0.2648, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6816


 77%|███████▋  | 7701/9999 [13:43<07:12,  5.32it/s, loss=tensor(0.1742, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6869


 78%|███████▊  | 7802/9999 [13:54<06:29,  5.64it/s, loss=tensor(0.1988, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6976


 93%|█████████▎| 9302/9999 [16:35<02:06,  5.52it/s, loss=tensor(0.1440, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6981


 97%|█████████▋| 9701/9999 [17:18<00:54,  5.45it/s, loss=tensor(0.1654, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7183


100%|██████████| 9999/9999 [17:49<00:00,  9.35it/s, loss=tensor(0.3224, device='cuda:0', grad_fn=<MeanBackward0>)]


In [21]:
torch.save(model.state_dict(), f'fewshot_model_{support}.bin')

In [23]:
! ls -alh *.bin

-rw-rw-r-- 1 jkfirst jkfirst 62M  8월  2 03:20 fewshot_model_5.bin
-rwxrwxr-x 1 jkfirst jkfirst 62M  4월 28 05:48 fewshot_model.bin


### load model

In [29]:
support = 5
criterion = Criterion(way=2, shot=support)

In [30]:
model = FewShotInduction(C=2,
                         S=support,
                         vocab_size=len(tokenizer),
                         embed_size=300,
                         hidden_size=128,
                         d_a=64,
                         iterations=3,
                         outsize=100)

In [31]:
model.load_state_dict(torch.load('./fewshot_model.bin', map_location='cpu'))

<All keys matched successfully>

In [32]:
def test():
    model.eval()
    correct = 0.
    count = 0.
    for i in range(100):
        data, target = test_dataloader.get_batch_test()
        data = data
        target = target
        predict = model(data)
        _, acc = criterion(predict, target)
        amount = len(target) - support * 2
        correct += acc * amount
        count += amount
        
    acc = correct / count
    print('Test Acc: {}'.format(acc))
    return acc

In [33]:
test()

Test Acc: 0.6952757239341736


tensor(0.6953)