# 试卷切分

In [1]:
import os
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from load_data import VecDataset
from trainer import MyTrainer
from model import PaperSegModel
from utils import get_logger, ROOT_DIR

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Example configuration, using DisenQNet as the pretrained embedding model.
class Args:
    """Static configuration shared by the cells of this notebook.

    Plays the role of a parsed command-line namespace: a single instance
    (``args``) is created right below the class, later cells attach derived
    attributes to it (for example the per-split data paths), and it is handed
    to ``MyTrainer`` as ``args``.
    """

    # Subject sub-directory used when building the data paths.
    subject = "math"
    # Root of the train/valid/test data, located under the project root.
    data_path = os.path.join(ROOT_DIR, "data")
    # Directory that receives the training log, the tensorboard events and
    # from which the trained model is reloaded for testing.
    checkpoint_dir = os.path.join(ROOT_DIR, "checkpoint")

    # Pretrained model passed to VecDataset; the directory below is a
    # placeholder that must point to a real checkpoint.
    pretrained_model_type = "disenqnet"
    pretrained_model_dir = "/path/to/disenqnet/checkpoint"

    # Device used for the datasets and the model.
    device = "cpu"

    # Learning rate read by the optimizer cell as ``args.lr``. It was missing,
    # which made that cell fail with AttributeError. 1e-3 is the default
    # learning rate of torch.optim.Adam.
    # NOTE(review): confirm this is the value the experiments were run with.
    lr = 1e-3


args = Args()

In [None]:
# Derive the tagged-paper directory of each split from the base data path and
# the subject, and store it on args as `<split>_data_path`.
for _split in ("train", "valid", "test"):
    setattr(
        args,
        f"{_split}_data_path",
        f"{args.data_path}/{_split}/{args.subject}/paper_txt_tagged",
    )

In [None]:
# Create the checkpoint directory before anything is written into it. The log
# file lives directly inside it, and previously the only makedirs call ran
# after get_logger, so the directory could still be missing on a fresh run.
# NOTE(review): assumes get_logger opens the log file immediately - confirm.
os.makedirs(args.checkpoint_dir, exist_ok=True)

# Logger writing to <checkpoint_dir>/train.log.
logfile = f'{args.checkpoint_dir}/train.log'
logger = get_logger(logfile)

# Tensorboard writer storing its event files in <checkpoint_dir>/tensorboard.
tensorboard_dir = f'{args.checkpoint_dir}/tensorboard'
os.makedirs(tensorboard_dir, exist_ok=True)
tensorboard_writer = SummaryWriter(tensorboard_dir)

## 加载向量数据集

In [None]:
# Build one VecDataset per split. For each split the embedding file path is
# derived from the text directory by swapping its last component
# ("paper_txt_tagged") for a file name, so the file sits next to the text data.
#
# Fixes compared with the previous version of this cell:
#   * the `.replace(...)` calls for train and valid were missing their closing
#     parenthesis, which made the whole cell a syntax error;
#   * the test path referenced an undefined name (`i2v_postfix`); it now
#     follows the naming used by the train split ("emb.<split>.pt").
train_set = VecDataset(
    text_data_dir=args.train_data_path,
    emb_data_path=args.train_data_path.replace("paper_txt_tagged", "emb.train.pt"),
    mode="train",
    pretrained_model_type=args.pretrained_model_type,
    pretrained_model_dir=args.pretrained_model_dir,
    device=args.device,
)

# The validation and test sets reuse the `paper_i2v` object exposed by the
# training set instead of creating their own.
valid_set = VecDataset(
    text_data_dir=args.valid_data_path,
    # NOTE(review): the file is named "emb.train.pt" although it belongs to the
    # validation split. It is kept as-is so existing files are still found;
    # it lives under the valid directory, so it does not clash with the
    # training file. Consider renaming it to "emb.valid.pt".
    emb_data_path=args.valid_data_path.replace("paper_txt_tagged", "emb.train.pt"),
    mode="valid",
    pretrained_model_type=args.pretrained_model_type,
    pretrained_model_dir=args.pretrained_model_dir,
    paper_i2v=train_set.paper_i2v,
    device=args.device,
)
test_set = VecDataset(
    text_data_dir=args.test_data_path,
    emb_data_path=args.test_data_path.replace("paper_txt_tagged", "emb.test.pt"),
    mode="test",
    pretrained_model_type=args.pretrained_model_type,
    pretrained_model_dir=args.pretrained_model_dir,
    paper_i2v=train_set.paper_i2v,
    device=args.device,
)
# Wrap each dataset in a DataLoader that uses the dataset's own collate
# function. Training batches hold 4 shuffled samples; validation and test
# are iterated one sample at a time in dataset order. num_workers=0 keeps
# data loading in the main process.
train_loader = DataLoader(
    train_set,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=train_set.collcate_fn,
)
valid_loader = DataLoader(
    valid_set,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=valid_set.collcate_fn,
)
test_loader = DataLoader(
    test_set,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=test_set.collcate_fn,
)

## 加载模型

In [None]:
# Create the segmentation model, sized to the embedding dimension reported by
# the training set, and move it to the configured device in one step.
model = PaperSegModel(
    embed_dim=train_set.embed_dim,
    hidden_dim=256,
    num_layers=2,
).to(args.device)
logger.info('prepare model have done!')
# Left disabled on purpose:
# model.save_pretrained(args.checkpoint_dir)

## 训练和评估

In [None]:
# Train with Adam at the learning rate taken from args, logging through the
# shared logger and tensorboard writer; the validation loader is passed to
# the trainer together with the training loader.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
trainer_kwargs = {
    "args": args,
    "model": model,
    "optimizer": optimizer,
    "logger": logger,
    "tensorboard_writer": tensorboard_writer,
}
trainer = MyTrainer(**trainer_kwargs)
trainer.train(train_loader, valid_loader)
logger.info("Finish training ... ")

In [None]:
# Reload the model from the checkpoint directory, move it to the configured
# device, and run the trainer's validation routine on the test loader.
# No optimizer or tensorboard writer is passed for this run.
model = PaperSegModel.from_pretrained(args.checkpoint_dir)
model = model.to(args.device)
trainer = MyTrainer(args=args, model=model, logger=logger)
trainer.valid(test_loader)
logger.info("Finish testing ... ")