In [1]:
import json
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer
import numpy as np
from transformers import RobertaTokenizer, RobertaModel,  RobertaForSequenceClassification
import torch.optim as optim
import tqdm
import  torch.nn.functional as F

In [2]:
device ="cpu"
batch_size = 64

In [3]:
class Seq2SeqDataset(Dataset):
    """
    A Simple Seq2Seq Dataset Implementation
    """
    def __init__(self, fact_filename, romantic_filename,funny_filename, tokenizer, add_bos_token=True, add_eos_token=True):
        data = []
        if fact_filename is not None:
            with open(fact_filename,'r') as f:
                line = f.readline()
                while line:
                    data.append({"source":"","target":line.replace('\n',''),"style":"fact"})
                    line = f.readline()
        if romantic_filename is not None:
            with open(romantic_filename,'r') as f:
                line = f.readline()
                while line:
                    data.append({"source":"","target":line.replace('\n',''),"style":"romantic"})
                    line = f.readline() 
                    
        if funny_filename is not None:
            with open(funny_filename,'r') as f:
                line = f.readline()
                while line:
                    data.append({"source":"","target":line.replace('\n',''),"style":"funny"})
                    line = f.readline()    

        self.data = data
        self.tokenizer = tokenizer
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token

    def __getitem__(self, index):
        item = self.data[index]
        target_token_ids = self.tokenizer.encode(item["target"], add_special_tokens=False)

        if self.add_bos_token:
            target_token_ids.insert(0, self.tokenizer.bos_token_id)

        if self.add_eos_token:
            target_token_ids.append(self.tokenizer.eos_token_id)


        item["target_token_ids"] = torch.LongTensor(target_token_ids)
        
        if item["style"]=='fact':
            item["source_token_ids"] = [1,0,0]
        elif item["style"]=='romantic':
            item["source_token_ids"] = [0, 1, 0]
        elif item["style"]=='funny':
            item["source_token_ids"] = [0, 0, 1]
        return item

    def __len__(self):
        return len(self.data)

    def collate_fn(self, batch):
        new_batch = {}
        new_batch["source_token_ids"] = torch.tensor([item["source_token_ids"] for item in batch])
        new_batch["target_token_ids"] = pad_sequence(
            [item["target_token_ids"] for item in batch], batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        new_batch["style"] = [item["style"] for item in batch]
        return new_batch



In [4]:
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

In [58]:
fact_filename = "./funny-GEN.txt"
romantic_filename = "./romantic-GEN.txt"
funny_filename = "./funny-GEN.txt"
gen_dataset = Seq2SeqDataset(fact_filename=fact_filename,romantic_filename=romantic_filename,funny_filename=funny_filename,tokenizer=tokenizer)
gen_dataloader = DataLoader(
            gen_dataset, batch_size=batch_size, shuffle=True, collate_fn=gen_dataset .collate_fn)

In [55]:
# fact_filename = "../StyleCaption/fact-test.txt"
# romantic_filename = "../StyleCaption/romantic-test.txt"
# funny_filename = "../StyleCaption/funny-test.txt"
# gen_dataset = Seq2SeqDataset(fact_filename,romantic_filename,funny_filename , tokenizer)
# gen_dataloader = DataLoader(
#             test_dataset, batch_size=batch_size, shuffle=True, collate_fn=train_dataset.collate_fn)

In [56]:
class_pth = "/home/hqh/Triple-Gan/classify.pth"
model = torch.load(class_pth, map_location=device)
device ="cpu"
batch_size = 64

In [59]:
with torch.no_grad():
    epoch_num = 0 
    epoch_real= 0
    fact_epoch=0
    romantic_epoch = 0
    funny_epoch=0
    for batch in tqdm.tqdm(gen_dataloader):
        sequence = batch['target_token_ids'].to(device)
        label_onehot = batch['source_token_ids'] # n*3,tensor
        label = torch.argmax(label_onehot,dim=-1).to(device) # n tensor
        sequence_logits = model(sequence).logits
        #打印
        pred_true = torch.argmax(sequence_logits,dim=-1) == label
        batch_real = ( pred_true).sum()
        batch_num = len(label)
        epoch_real += batch_real
        epoch_num += len(label)
        #分门别类
        fact_epoch +=  pred_true[(label==0).nonzero().reshape(-1)].sum()
        romantic_epoch +=  pred_true[(label==1).nonzero().reshape(-1)].sum()
        funny_epoch +=  pred_true[(label==2).nonzero().reshape(-1)].sum()
    acc_rate = epoch_real/epoch_num
    print("在验证集上,正确个数：{},总个数：{},准确率:{},fact:{},romantic:{},funny:{}"
      .format(epoch_real,epoch_num,acc_rate,fact_epoch*3/epoch_num,romantic_epoch*3/epoch_num,funny_epoch*3/epoch_num))

100%|██████████| 24/24 [00:11<00:00,  2.10it/s]

在验证集上,正确个数：809,总个数：1536,准确率:0.5266926884651184,fact:0.361328125,romantic:0.677734375,funny:0.541015625



