In [1]:
import os
import torch
from model import FewShotInduction
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from glob import glob
from tqdm import tqdm
from torch import optim
from torch.nn.utils.rnn import pad_sequence
from criterion import Criterion

In [2]:
data_path = 'Amazon_few_shot'

In [3]:
# 반드시 do_lower_case=True로 해야 한다.
# bert-base-uncased는 영어 데이터를 소문자로 변환해서 학습한 모델이기 때문이다.
tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased', do_lower_case=True)

### make dataset and dataloader

In [4]:
class AmazonDataset():
    def __init__(self, data_path, tokenizer, dtype):
        self.data_path = data_path
        self.tokenizer = tokenizer
        with open(f'{dtype}.list', 'r') as f:
            self.categories = [oneline.rstrip() for oneline in f]
        self.support_dataset = {}
        self.dataset = {}
        for category in tqdm(self.categories, desc='reading categories'):
            self.dataset[category] = {
                'neg': self.get_data(category, 'neg', dtype),
                'pos': self.get_data(category, 'pos', dtype)
            }
        
        if dtype == 'test' or dtype == 'dev':
            for category in tqdm(self.categories, desc='reading categories for support'):
                self.support_dataset[category] = {
                    'neg': self.get_data(category, 'neg', 'train'),
                    'pos': self.get_data(category, 'pos', 'train'),
                }
        
    def read_files(self, category, label, dtype):
        data = {
            'text': [],
            'label': []
        }
        for t in ['t2', 't4', 't5']:
            filename = f'{category}.{t}.{dtype}'
            with open(os.path.join(self.data_path, filename), 'r') as f:
                for oneline in f:
                    oneline = oneline.rstrip()
                    text = oneline[:-2]
                    if int(oneline[-2:]) == 1 and label == 'pos':
                        tensor = self.tokenizer(text, return_tensors='pt')
                        data['text'].append(tensor['input_ids'][0])
                        data['label'].append(1)
                    elif int(oneline[-2:]) == -1 and label == 'neg':
                        tensor = self.tokenizer(text, return_tensors='pt')
                        data['text'].append(tensor['input_ids'][0])
                        data['label'].append(0)
        data['label'] = torch.tensor(data['label'])
        return data
    
    def get_data(self, category, label, dtype):
        data = self.read_files(category, label, dtype)
        return data

In [5]:
train_dataset = AmazonDataset(data_path, tokenizer, 'train')
dev_dataset = AmazonDataset(data_path, tokenizer, 'dev')
test_dataset = AmazonDataset(data_path, tokenizer, 'test')

reading categories: 100%|██████████| 14/14 [02:58<00:00, 12.76s/it]
reading categories: 100%|██████████| 5/5 [00:01<00:00,  4.90it/s]
reading categories for support: 100%|██████████| 5/5 [00:08<00:00,  1.69s/it]
reading categories: 100%|██████████| 4/4 [00:20<00:00,  5.02s/it]
reading categories for support: 100%|██████████| 4/4 [00:00<00:00, 15.27it/s]


In [6]:
def pad_text(a_text, b_text):
    a_text_len = a_text.shape[1]
    b_text_len = b_text.shape[1]

    if a_text_len > b_text_len:
        b_text = torch.cat([b_text, torch.zeros(b_text.shape[0], a_text_len-b_text_len).long()], dim=1)
    else:
        a_text = torch.cat([a_text, torch.zeros(a_text.shape[0], b_text_len-a_text_len).long()], dim=1)
        
    return a_text, b_text

In [7]:
class AmazonDataLoader():
    def __init__(self, dataset, batch_size, n_support):
        assert n_support % 2 == 0, 'n_support should be multiple of 2'
        self.dataset = dataset
        self.batch_size = batch_size
        self.n_support = n_support
        self.neg_idx = {k:0 for k in dataset.dataset}
        self.pos_idx = {k:0 for k in dataset.dataset}
        self.neg_len = {k:len(dataset.dataset[k]['neg']['text']) for k in dataset.dataset}
        self.pos_len = {k:len(dataset.dataset[k]['pos']['text']) for k in dataset.dataset}
        self.neg = {k:dataset.dataset[k]['neg'] for k in dataset.dataset}
        self.pos = {k:dataset.dataset[k]['pos'] for k in dataset.dataset}
        self.idx = 0
        self.categories = [k for k in dataset.dataset]
        
        # prepare for test dataset, support dataset should come from "*.train"
        self.neg_support_idx = {}
        self.pos_support_idx = {}
        self.neg_support_len = {}
        self.pos_support_len = {}
        if self.dataset.support_dataset:
            self.neg_support_idx = {k:0 for k in self.dataset.support_dataset}
            self.pos_support_idx = {k:0 for k in self.dataset.support_dataset}
            self.neg_support_len = {k:len(self.dataset.support_dataset[k]['neg']['text']) for k in self.dataset.support_dataset}
            self.pos_support_len = {k:len(self.dataset.support_dataset[k]['pos']['text']) for k in self.dataset.support_dataset}
        
    def get_batch(self):
        category = self.categories[self.idx % len(self.categories)]
        neg = self.neg[category]
        pos = self.pos[category]
        neg_start_idx = self.neg_idx[category] % self.neg_len[category]
        pos_start_idx = self.pos_idx[category] % self.pos_len[category]
        
        # prepare negative/positive dataset
        neg_text = neg['text'][neg_start_idx:neg_start_idx+(self.batch_size//2)]
        pos_text = pos['text'][pos_start_idx:pos_start_idx+(self.batch_size//2)]
        neg_label = neg['label'][neg_start_idx:neg_start_idx+(self.batch_size//2)]
        pos_label = pos['label'][pos_start_idx:pos_start_idx+(self.batch_size//2)]
        self.neg_idx[category] += (self.batch_size//2)
        self.pos_idx[category] += (self.batch_size//2)
        
        if len(neg_text) + len(pos_text) != self.batch_size:
            return self.get_batch()
            
        # padding text dataset
        neg_text = pad_sequence([n for n in neg_text], batch_first=True)
        pos_text = pad_sequence([p for p in pos_text], batch_first=True)
        neg_text, pos_text = pad_text(neg_text, pos_text)
            
        # prepare support/query text
        neg_support_text = neg_text[:self.n_support//2]
        pos_support_text = pos_text[:self.n_support//2]
        neg_query_text = neg_text[self.n_support//2:]
        pos_query_text = pos_text[self.n_support//2:]
        
        # prepare support/query label
        neg_support_label = neg_label[:self.n_support//2]
        pos_support_label = pos_label[:self.n_support//2]
        neg_query_label = neg_label[self.n_support//2:]
        pos_query_label = pos_label[self.n_support//2:]
        
        # merge support/query text
        support_text = torch.cat([neg_support_text, pos_support_text], dim=0)
        query_text = torch.cat([neg_query_text, pos_query_text], dim=0)
        
        # merge support/query label
        support_label = torch.cat([neg_support_label, pos_support_label], dim=0)
        query_label = torch.cat([neg_query_label, pos_query_label], dim=0)
        
        # make data and label
        data = torch.cat([support_text, query_text], dim=0)
        label = torch.cat([support_label, query_label], dim=0)
        
        # increase category index
        self.idx += 1
        return data, label
    
    def get_batch_test(self):
        assert self.dataset.support_dataset, 'support_dataset is empty'
        
        category = self.categories[self.idx % len(self.categories)]
        neg = self.neg[category]
        pos = self.pos[category]
        neg_query_start_idx = self.neg_idx[category] % self.neg_len[category]
        pos_query_start_idx = self.pos_idx[category] % self.pos_len[category]
        neg_support_start_idx = self.neg_support_idx[category] % self.neg_support_len[category]
        pos_support_start_idx = self.pos_support_idx[category] % self.pos_support_len[category]
        
        # prepare negative/positive support dataset from support_dataset
        category_suuport_dataset = self.dataset.support_dataset[category]
        neg_support_text = category_suuport_dataset['neg']['text'][neg_support_start_idx:neg_support_start_idx+self.n_support//2]
        pos_support_text = category_suuport_dataset['pos']['text'][pos_support_start_idx:pos_support_start_idx+self.n_support//2]
        neg_support_label = category_suuport_dataset['neg']['label'][neg_support_start_idx:neg_support_start_idx+self.n_support//2]
        pos_support_label = category_suuport_dataset['pos']['label'][pos_support_start_idx:pos_support_start_idx+self.n_support//2]
        self.neg_support_idx[category] += (self.n_support//2)
        self.pos_support_idx[category] += (self.n_support//2)
        
        # prepare negative/positive query dataset
        neg_query_text = neg['text'][neg_query_start_idx:neg_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        pos_query_text = pos['text'][pos_query_start_idx:pos_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        neg_query_label = neg['label'][neg_query_start_idx:neg_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        pos_query_label = pos['label'][pos_query_start_idx:pos_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        self.neg_idx[category] += (self.batch_size//2 - self.n_support//2)
        self.pos_idx[category] += (self.batch_size//2 - self.n_support//2)
        
        # padding support text dataset
        if self.n_support:
            neg_support_text = pad_sequence([n for n in neg_support_text], batch_first=True)
            pos_support_text = pad_sequence([n for n in pos_support_text], batch_first=True)
            neg_support_text, pos_support_text = pad_text(neg_support_text, pos_support_text)
        else:
            neg_support_text = torch.tensor([[]])
            pos_support_text = torch.tensor([[]])
            
        # padding text dataset
        neg_query_text = pad_sequence([n for n in neg_query_text], batch_first=True)
        pos_query_text = pad_sequence([p for p in pos_query_text], batch_first=True)
        neg_query_text, pos_query_text = pad_text(neg_query_text, pos_query_text)

        # concatenating support/query text dataset
        support_text = torch.cat([neg_support_text, pos_support_text], dim=0)
        query_text = torch.cat([neg_query_text, pos_query_text], dim=0)
        support_text, query_text = pad_text(support_text, query_text)

        # make final data and label
        if self.n_support:
            data = torch.cat([support_text, query_text], dim=0)
        else:
            data = query_text
        label = torch.cat([neg_support_label, pos_support_label, neg_query_label, pos_query_label], dim=0)
        return data, label

In [8]:
support = 5

In [9]:
train_dataloader = AmazonDataLoader(train_dataset, batch_size=64, n_support=support*2)
dev_dataloader = AmazonDataLoader(dev_dataset, batch_size=64, n_support=support*2)
test_dataloader = AmazonDataLoader(test_dataset, batch_size=64, n_support=support*2)

In [10]:
for i in range(10):
    d, l = train_dataloader.get_batch()
    print(d.shape, l.float().mean())

torch.Size([64, 149]) tensor(0.5000)
torch.Size([64, 460]) tensor(0.5000)
torch.Size([64, 254]) tensor(0.5000)
torch.Size([64, 262]) tensor(0.5000)
torch.Size([64, 1283]) tensor(0.5000)
torch.Size([64, 1658]) tensor(0.5000)
torch.Size([64, 613]) tensor(0.5000)
torch.Size([64, 359]) tensor(0.5000)
torch.Size([64, 530]) tensor(0.5000)
torch.Size([64, 602]) tensor(0.5000)


In [11]:
for i in range(10):
    d, l = dev_dataloader.get_batch_test()
    print(d.shape, l.float().mean())

torch.Size([64, 327]) tensor(0.5000)
torch.Size([55, 181]) tensor(0.5818)
torch.Size([64, 198]) tensor(0.5000)
torch.Size([46, 295]) tensor(0.6957)
torch.Size([64, 197]) tensor(0.5000)
torch.Size([64, 276]) tensor(0.5000)
torch.Size([55, 186]) tensor(0.5818)
torch.Size([64, 270]) tensor(0.5000)
torch.Size([26, 130]) tensor(0.4615)
torch.Size([64, 327]) tensor(0.5000)


In [12]:
for i in range(10):
    d, l = test_dataloader.get_batch_test()
    print(d.shape, l.float().mean())

torch.Size([64, 743]) tensor(0.5000)
torch.Size([64, 841]) tensor(0.5000)
torch.Size([64, 1386]) tensor(0.5000)
torch.Size([64, 706]) tensor(0.5000)
torch.Size([64, 1026]) tensor(0.5000)
torch.Size([64, 1126]) tensor(0.5000)
torch.Size([64, 1116]) tensor(0.5000)
torch.Size([64, 1333]) tensor(0.5000)
torch.Size([64, 568]) tensor(0.5000)
torch.Size([64, 570]) tensor(0.5000)


In [13]:
# class AmazonDataLoader():
#     def __init__(self, dataset, amount, batch_size):
#         self.amount = amount
#         self.dataset = dataset
#         self.batch_size = batch_size
#         self.categories = list(dataset.dataset.keys())
#         self.category_idx = 0
#         self.indices_per_category = {
#             category: 0 for category in self.categories
#         }
#         self.n_data_per_category = {
#             category: len(dataset.dataset[category]['indice']) for category in self.categories
#         }
#     def get_batch(self):
#         idx = self.category_idx % len(self.categories)
#         category = self.categories[idx]
#         n_data_per_category = self.n_data_per_category[category]
#         start_idx = self.indices_per_category[category] % n_data_per_category
#         indice = self.dataset.dataset[category]['indice'][start_idx:start_idx+self.batch_size]
#         labels = self.dataset.dataset[category]['label'][start_idx:start_idx+self.batch_size]
#         self.indices_per_category[category] += self.batch_size
#         self.category_idx += 1
        
#         if len(indice) != self.batch_size:
#             return self.get_batch()
        
#         indice = pad_sequence(indice, batch_first=True)
#         return indice, labels, category

### define model

In [14]:
model = FewShotInduction(C=2,
                         S=support,
                         vocab_size=len(tokenizer),
                         embed_size=300,
                         hidden_size=128,
                         d_a=64,
                         iterations=3,
                         outsize=100)
model = model.cuda()

In [15]:
len(tokenizer)

30522

In [16]:
optimizer = optim.Adam(model.parameters(), lr=float(1e-4))

In [17]:
criterion = Criterion(way=2, shot=support)

In [18]:
def train(episode):
    model.train()
    data, target = train_dataloader.get_batch()
    data = data.cuda()
    target = target.cuda()
    optimizer.zero_grad()
    predict = model(data)
    loss, acc = criterion(predict, target)
    loss.backward()
    optimizer.step()
    return loss

In [19]:
def dev(episode):
    model.eval()
    correct = 0.
    count = 0.
    for i in range(100):
        data, target = dev_dataloader.get_batch_test()
        data = data.cuda()
        target = target.cuda()
        predict = model(data)
        _, acc = criterion(predict, target)
        amount = len(target) - support * 2
        correct += acc * amount
        count += amount
    acc = correct / count
    return acc

In [20]:
def test():
    model.eval()
    correct = 0.
    count = 0.
    for i in range(100):
        data, target = test_dataloader.get_batch_test()
        data = data.cuda()
        target = target.cuda()
        predict = model(data)
        _, acc = criterion(predict, target)
        amount = len(target) - support * 2
        correct += acc * amount
        count += amount
        
    acc = correct / count
    print('Test Acc: {}'.format(acc))
    return acc

In [21]:
dev_interval = 100
best_acc = -1.0

In [22]:
tbar = tqdm(range(1, 10000))
for episode in tbar:
    
    loss = train(episode)
    if episode % dev_interval == 0:
        acc = dev(episode)
        if acc > best_acc:
            print('Better acc! Saving model! -> {:.4f}'.format(acc))
            best_acc = acc
    tbar.set_postfix(loss=loss)  

  1%|          | 101/9999 [00:24<1:39:28,  1.66it/s, loss=tensor(0.4911, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.5374


  2%|▏         | 201/9999 [00:49<1:41:05,  1.62it/s, loss=tensor(0.5000, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.5421


  5%|▌         | 500/9999 [02:02<2:03:14,  1.28it/s, loss=tensor(0.4907, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.5428


 32%|███▏      | 3200/9999 [12:54<1:27:25,  1.30it/s, loss=tensor(0.5259, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.5548


 36%|███▌      | 3601/9999 [14:31<1:03:33,  1.68it/s, loss=tensor(0.3047, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6022


 37%|███▋      | 3701/9999 [14:56<1:05:37,  1.60it/s, loss=tensor(0.4654, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6230


 39%|███▉      | 3900/9999 [15:44<1:18:19,  1.30it/s, loss=tensor(0.4810, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6518


 40%|████      | 4000/9999 [16:07<1:18:58,  1.27it/s, loss=tensor(0.2676, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6539


 41%|████      | 4100/9999 [16:32<1:18:33,  1.25it/s, loss=tensor(0.2664, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6652


 42%|████▏     | 4200/9999 [16:56<1:16:40,  1.26it/s, loss=tensor(0.2961, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6730


 43%|████▎     | 4300/9999 [17:20<1:14:57,  1.27it/s, loss=tensor(0.1779, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6873


 45%|████▌     | 4500/9999 [18:09<1:10:59,  1.29it/s, loss=tensor(0.1996, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.6883


 50%|█████     | 5001/9999 [20:11<50:16,  1.66it/s, loss=tensor(0.1750, device='cuda:0', grad_fn=<MeanBackward0>)]  

Better acc! Saving model! -> 0.6964


 51%|█████     | 5100/9999 [20:35<1:05:25,  1.25it/s, loss=tensor(0.3406, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7037


 59%|█████▉    | 5900/9999 [23:48<53:08,  1.29it/s, loss=tensor(0.0876, device='cuda:0', grad_fn=<MeanBackward0>)]  

Better acc! Saving model! -> 0.7123


 60%|██████    | 6000/9999 [24:13<51:30,  1.29it/s, loss=tensor(0.1551, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7133


 62%|██████▏   | 6200/9999 [25:01<50:09,  1.26it/s, loss=tensor(0.2443, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7195


 65%|██████▌   | 6501/9999 [26:14<34:55,  1.67it/s, loss=tensor(0.1500, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7287


 67%|██████▋   | 6700/9999 [27:02<41:29,  1.33it/s, loss=tensor(0.1373, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7345


 76%|███████▌  | 7600/9999 [30:41<32:03,  1.25it/s, loss=tensor(0.2500, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7393


 82%|████████▏ | 8200/9999 [33:06<23:03,  1.30it/s, loss=tensor(0.0173, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7446


 84%|████████▍ | 8400/9999 [33:55<21:26,  1.24it/s, loss=tensor(0.2435, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7486


 90%|█████████ | 9000/9999 [36:20<13:24,  1.24it/s, loss=tensor(0.0522, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7496


 92%|█████████▏| 9200/9999 [37:09<10:15,  1.30it/s, loss=tensor(0.0444, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7497


 93%|█████████▎| 9301/9999 [37:33<07:19,  1.59it/s, loss=tensor(0.2773, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7528


 94%|█████████▍| 9400/9999 [37:57<07:49,  1.28it/s, loss=tensor(0.0272, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7531


 98%|█████████▊| 9800/9999 [39:33<02:45,  1.20it/s, loss=tensor(0.2823, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7546


 99%|█████████▉| 9900/9999 [39:57<01:16,  1.30it/s, loss=tensor(0.2743, device='cuda:0', grad_fn=<MeanBackward0>)]

Better acc! Saving model! -> 0.7551


100%|██████████| 9999/9999 [40:19<00:00,  4.13it/s, loss=tensor(0.2489, device='cuda:0', grad_fn=<MeanBackward0>)]


In [23]:
torch.save(model.state_dict(), f'fewshot_model_{support}.bin')

In [24]:
! ls -alh *.bin

-rw-r--r-- 1 jkfirst jkfirst 62M 10월  7 13:11 fewshot_model.bin
-rw-r--r-- 1 jkfirst jkfirst 62M 10월  7 13:11 fewshot_model_0.bin
-rw-rw-r-- 1 jkfirst jkfirst 62M 10월  7 14:44 fewshot_model_5.bin


In [25]:
test_dataset.categories

['books', 'dvd', 'electronics', 'kitchen_housewares']

In [26]:
dev_dataset.categories

['grocery',
 'office_products',
 'outdoor_living',
 'gourmet_food',
 'jewelry_watches']

In [27]:
train_dataset.categories

['apparel',
 'automotive',
 'baby',
 'beauty',
 'camera_photo',
 'cell_phones_service',
 'computer_video_games',
 'health_personal_care',
 'magazines',
 'music',
 'software',
 'sports_outdoors',
 'toys_games',
 'video']

In [28]:
test()

Test Acc: 0.6549707651138306


tensor(0.6550, device='cuda:0')

In [28]:
test_dataset.categories

['books', 'dvd', 'electronics', 'kitchen_housewares']

In [29]:
n_neg = 0
n_pos = 0
for c in test_dataset.categories:
    n_neg += len(test_dataset.dataset[c]['neg']['label'])
    n_pos += len(test_dataset.dataset[c]['pos']['label'])

In [30]:
print(n_neg, n_pos)

1987 7178


In [29]:
from torch.nn.modules.loss import _Loss


class Criterion_(_Loss):
    def __init__(self, way=2, shot=5):
        super(Criterion_, self).__init__()
        self.amount = way * shot

    def forward(self, probs, target, return_pred_label=False):  # (Q,C) (Q)
        target = target[self.amount:]
        target_onehot = torch.zeros_like(probs)
        #print('** sum of probs/target_onehot: {} {}'.format(probs.sum(), target_onehot.sum()))
        target_onehot = target_onehot.scatter(1, target.reshape(-1, 1), 1)
        loss = torch.mean((probs - target_onehot) ** 2)
        pred = torch.argmax(probs, dim=1)
        acc = torch.sum(target == pred).float() / target.shape[0]
        #print('** acc: {}'.format(acc))

        if return_pred_label:
            return loss, acc, pred, target
        else:
            return loss, acc

In [30]:
criterion_ = Criterion_(way=2, shot=support)

In [35]:
from sklearn.metrics import confusion_matrix
import numpy as np

In [36]:
def test_():
    model.eval()
    correct = 0.
    count = 0.
    p_list = []
    l_list = []
    r_list = []
    for i in range(100):
        data, target = test_dataloader.get_batch_test()
        data = data.cuda()
        target = target.cuda()
        predict = model(data)
        _, acc, p, l = criterion_(predict, target, return_pred_label=True)
        amount = len(target) - 5 * 2    # 5 = support
        correct += acc * amount
        count += amount
        
        r = [1 if np.random.random() < 0.5 else 0 for _ in range(len(l))]
        p_list.extend(list(p.cpu().numpy()))
        l_list.extend(list(l.cpu().numpy()))
        r_list.extend(r)
    acc = correct / count
    print('Test Acc: {}'.format(acc))
    mat = confusion_matrix(l_list, p_list)
    print(mat)
    
    rmat = confusion_matrix(l_list, r_list)
    print(rmat)
    
    return acc

In [37]:
test_()

Test Acc: 0.6597026586532593
[[1217 1408]
 [ 400 2288]]
[[1322 1303]
 [1343 1345]]


tensor(0.6597, device='cuda:0')

In [70]:
1707+921+658+2012

5298

In [72]:
(1293+1335)/(1293+1335+1300+1370)

0.4960362400906002