In [None]:
import torch
import argparse
import os
import logging
import time
from torch import nn
from contextlib import nullcontext
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from transformers import BertTokenizer, BertModel, AdamW
#from tqdm.auto import tqdm

In [None]:
# Distributed launch settings, read from the environment.
# Both fall back to the single-process values (1 process, rank 0) when unset.
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
RANK = int(os.getenv("RANK", "0"))

In [None]:
def parse_args(args):
    """Parse the command-line options of the PERT fine-tuning example.

    Args:
        args: List of argument strings to parse. ``None`` makes argparse
            fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option declared below
        (dashes in option names become underscores, e.g. ``batch_size``).
    """
    parser = argparse.ArgumentParser(description="PyTorch PERT Example")
    # ``%(default)s`` keeps the help text in sync with the real default value.
    parser.add_argument("--batch-size", type=int, default=16, metavar="N",
                        help="input batch size for training (default: %(default)s)")
    parser.add_argument("--test-batch-size", type=int, default=1000, metavar="N",
                        help="input batch size for testing (default: %(default)s)")
    parser.add_argument("--epochs", type=int, default=1, metavar="N",
                        help="number of epochs to train (default: %(default)s)")
    parser.add_argument("--lr", type=float, default=1e-5, metavar="LR",
                        help="learning rate (default: %(default)s)")
    parser.add_argument("--seed", type=int, default=1, metavar="S",
                        help="random seed (default: %(default)s)")
    parser.add_argument("--dataset", type=int, default=1, metavar="D",
                        help="dataset size (default 1 * 9600)")
    parser.add_argument("--save-model", action="store_true", default=False,
                        help="For Saving the current Model")
    # Kept as a string flag ("true"/"false"); consumers compare it case-insensitively.
    parser.add_argument("--local-only", type=str, default="false",
                        help="If set to true, then load model from disk")
    parser.add_argument("--model-path", type=str, default="/ppml/model",
                        help="Where to load model")
    parser.add_argument("--device", type=str, default="cpu",
                        help="Where to train model, default is cpu")
    # Only for test purpose
    parser.add_argument("--load-model", action="store_true", default=False,
                        help="For loading the current model")
    parser.add_argument("--mini-batch", type=int, default=0, metavar="M",
                        help="If set, the PyTorch will conduct M local-batch computation before doing a all_reduce sync")
    parser.add_argument("--log-interval", type=int, default=2, metavar="N",
                        help="how many batches to wait before logging training status")
    parser.add_argument("--log-path", type=str, default="",
                        help="Path to save logs. Print to StdOut if log-path is not set")
    return parser.parse_args(args)

In [None]:
def set_log_path(args):
    """Configure the root logger through ``logging.basicConfig``.

    Records are emitted at DEBUG level and above. When ``args.log_path`` is
    the empty string no ``filename`` is passed, so ``basicConfig`` uses its
    default stream handler; otherwise records are written to that file.

    Args:
        args: Parsed options; only ``args.log_path`` is read.
    """
    log_config = {
        "format": "%(asctime)s %(levelname)-8s %(message)s",
        "datefmt": "%Y-%m-%dT%H:%M:%SZ",
        "level": logging.DEBUG,
    }
    if args.log_path != "":
        log_config["filename"] = args.log_path
    logging.basicConfig(**log_config)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over one split of ``seamew/ChnSentiCorp``.

    The split can be replicated ``dataset_load`` times to build a larger
    dataset out of the same samples.

    Attributes are set in ``__init__``:
        dataset_load: number of times the split is replicated.
        data: dict mapping a contiguous integer index to a sample.
    """

    # data_type is actually split, so that we can define dataset for train set/validate set
    def __init__(self, data_type, dataset_load):
        self.dataset_load = dataset_load
        self.data = self.load_data(data_type)

    def load_data(self, data_type):
        """Load the split ``data_type`` and index it ``dataset_load`` times.

        Returns:
            dict whose keys run from 0 to ``dataset_load * n - 1`` where ``n``
            is the number of samples in the split; key ``idx + i * n`` holds
            sample ``idx`` of replica ``i``.
        """
        tmp_dataset = load_dataset(path='seamew/ChnSentiCorp', split=data_type)
        # Walk the loaded split exactly once; every replica reuses the same
        # sample objects instead of re-reading the split per replica.
        samples = list(tmp_dataset)
        split_size = len(samples)
        return {
            idx + replica * split_size: sample
            for replica in range(self.dataset_load)
            for idx, sample in enumerate(samples)
        }

    def __len__(self):
        """Return the number of indexed samples (replicas included)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored under integer index ``idx``."""
        return self.data[idx]

In [None]:
def get_dataloader(args, train_data, valid_data):
    """Build the training and validation data loaders.

    The tokenizer is loaded from ``args.model_path`` (local files only) when
    ``args.local_only`` is "true" (case-insensitive), otherwise from the
    ``hfl/chinese-pert-base`` checkpoint. When a process group is initialised
    each loader gets a ``DistributedSampler``; otherwise both loaders shuffle.

    Args:
        args: Parsed options; reads ``local_only``, ``model_path``,
            ``batch_size``, ``test_batch_size`` and ``seed``.
        train_data: Dataset used for training.
        valid_data: Dataset used for validation.

    Returns:
        Tuple ``(train_dataloader, valid_dataloader)``.
    """
    # init tokenizer
    if args.local_only.lower() == "true":
        tokenizer = BertTokenizer.from_pretrained(
            args.model_path, model_max_length=512, local_files_only=True)
    else:
        tokenizer = BertTokenizer.from_pretrained(
            'hfl/chinese-pert-base', model_max_length=512)

    def collate(samples):
        # Bind the tokenizer so the loaders can call the collate hook with one argument.
        return collate_fn(samples, tokenizer)

    if is_distributed():
        train_sampler = DistributedSampler(
            train_data, num_replicas=WORLD_SIZE, rank=RANK, shuffle=True, drop_last=False, seed=args.seed)
        valid_sampler = DistributedSampler(
            valid_data, num_replicas=WORLD_SIZE, rank=RANK, shuffle=True, drop_last=False, seed=args.seed)
        train_dataloader = DataLoader(
            train_data, batch_size=args.batch_size, collate_fn=collate, sampler=train_sampler)
        valid_dataloader = DataLoader(
            valid_data, batch_size=args.test_batch_size, collate_fn=collate, sampler=valid_sampler)
    else:
        train_dataloader = DataLoader(
            train_data, batch_size=args.batch_size, shuffle=True, collate_fn=collate)
        valid_dataloader = DataLoader(
            valid_data, batch_size=args.test_batch_size, shuffle=True, collate_fn=collate)
    # Return the validation loader as the second element (the training loader
    # used to be returned twice, so validation ran on the training data).
    return train_dataloader, valid_dataloader

In [None]:
def should_distribute():
    """Tell whether a distributed process group should be set up.

    True only when torch.distributed is available and WORLD_SIZE is above 1.
    """
    if not dist.is_available():
        return False
    return WORLD_SIZE > 1

In [None]:
def is_distributed():
    """Tell whether a torch.distributed process group is currently initialised."""
    if not dist.is_available():
        return False
    return dist.is_initialized()

Define the dataset so that it is easier to load the different splits of the dataset.

In [None]:
# Turn a list of raw samples into one training batch
def collate_fn(batch_samples, tokenizer):
    """Tokenize the texts of a batch and pack its labels into a tensor.

    Args:
        batch_samples: Iterable of samples, each exposing ``'text'`` and
            ``'label'`` keys.
        tokenizer: Callable applied to the list of texts with padding,
            truncation and ``return_tensors="pt"``.

    Returns:
        Tuple ``(X, y)``: the tokenizer output and a tensor of integer labels.
    """
    # Single pass over the samples, reading text and label together.
    pairs = [(item['text'], int(item['label'])) for item in batch_samples]
    texts = [text for text, _ in pairs]
    labels = [label for _, label in pairs]
    # The tokenizer converts the raw strings into the inputs the model consumes
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    return encoded, torch.tensor(labels)

In [None]:
# Sentence classifier: BERT encoder followed by a linear head
class NeuralNetwork(nn.Module):
    """BERT encoder with a 2-way linear classifier on the first token's vector."""

    def __init__(self, args):
        """Load the encoder and create the classification head.

        The encoder comes from ``args.model_path`` (local files only) when
        ``args.local_only`` is "true" (case-insensitive), otherwise from the
        ``hfl/chinese-pert-base`` checkpoint.
        """
        super().__init__()
        load_from_disk = args.local_only.lower() == "true"
        if load_from_disk:
            encoder = BertModel.from_pretrained(args.model_path, local_files_only=True)
        else:
            encoder = BertModel.from_pretrained('hfl/chinese-pert-base')
        self.bert_encoder = encoder
        # 768 input features -> 2 output logits
        self.classifier = nn.Linear(768, 2)

    def forward(self, x):
        """Return the classifier logits for the keyword inputs in ``x``."""
        hidden_states = self.bert_encoder(**x).last_hidden_state
        # Position 0 along the second axis is fed to the classifier.
        first_token = hidden_states[:, 0]
        return self.classifier(first_token)

In [None]:
# define training loop
def train_loop(args, dataloader, model, loss_fn, optimizer, epoch, total_loss):
    """Run one training epoch over ``dataloader``.

    When ``args.mini_batch`` is positive, gradients are accumulated and the
    optimizer only steps every ``args.mini_batch`` batches; with more than one
    process the intermediate batches run under ``model.no_sync`` when the
    model provides it. The per-batch loss is written to TensorBoard under the
    tag ``loss`` and logged every ``args.log_interval`` batches.

    Args:
        args: Parsed options; reads ``device``, ``mini_batch`` and
            ``log_interval``.
        dataloader: Sized iterable yielding ``(X, y)`` batches.
        model: Module to train.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer updating ``model``'s parameters.
        epoch: 1-based epoch number, used for the TensorBoard step and logs.
        total_loss: Running loss sum carried over from previous epochs.

    Returns:
        Tuple ``(total_loss, total_dataset)``: the updated running loss and
        the number of samples processed in this epoch.
    """
    # Imported once per call instead of once per batch.
    from torch.utils.tensorboard import SummaryWriter

    # Set to train mode
    model.train()
    total_dataset = 0
    num_batches = len(dataloader)
    optimizer.zero_grad(set_to_none=True)
    # One writer for the whole epoch; a fresh writer per batch was never
    # closed and opened a new event file every time.
    writer = SummaryWriter('/ppml/test/pert.log')
    try:
        for batch, (X, y) in enumerate(dataloader, start=1):
            skip_sync = (
                WORLD_SIZE > 1
                and args.mini_batch > 0
                and batch % args.mini_batch != 0
                # Plain (non-DDP) modules have no no_sync(); fall back to a no-op.
                and hasattr(model, "no_sync")
            )
            my_context = model.no_sync if skip_sync else nullcontext
            with my_context():
                X, y = X.to(args.device), y.to(args.device)
                # Forward pass
                pred = model(X)
                loss = loss_fn(pred, y)
                loss.backward()
                loss_value = loss.item()
                total_loss += loss_value
            if args.mini_batch == 0 or batch % args.mini_batch == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            # Count the samples really seen; the last batch may be smaller
            # than args.batch_size.
            total_dataset += len(y)
            writer.add_scalar('loss', loss_value, (epoch - 1) * num_batches + batch)
            if batch % args.log_interval == 0:
                msg = "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
                    epoch, batch, num_batches,
                    100. * batch / num_batches, loss_value)
                logging.info(msg)
    finally:
        # Flush and release the event file even if a batch raises.
        writer.close()

    return total_loss, total_dataset

In [None]:
# define test loop to get acc
def test_loop(args, dataloader, model, mode='Test'):
    assert mode in ['Valid', 'Test']
    size = len(dataloader.dataset)
    correct = 0

    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(args.device), y.to(args.device)
            pred = model(X)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    #correct /= size
    #correct *= WORLD_SIZE
    correct = correct / (size / WORLD_SIZE)
    print(f"{mode} Accuracy: {(100*correct):>0.1f}%\n")
    return correct

In [None]:
def do_train(args, train_dataloader, valid_dataloader, model, loss_fn, optimizer):
    """Train for ``args.epochs`` epochs, validating after each one.

    Prints the elapsed time, processed sample count and throughput of every
    epoch, then the averages over the whole run.

    Args:
        args: Parsed options; ``args.epochs`` is read here and ``args`` is
            forwarded to ``train_loop`` and ``test_loop``.
        train_dataloader: Loader passed to ``train_loop``.
        valid_dataloader: Loader passed to ``test_loop``.
        model: Module to train.
        loss_fn: Loss callable forwarded to ``train_loop``.
        optimizer: Optimizer forwarded to ``train_loop``.

    Returns:
        The ``model`` object that was passed in, so the caller can save it.
    """
    total_loss = 0.
    total_time = 0.
    total_throughput = 0.

    # NOTE(review): if the loaders use a DistributedSampler, its set_epoch()
    # is not called here, so the shuffle order would repeat every epoch --
    # confirm whether that is intended.
    for t in range(args.epochs):
        epoch = t + 1
        # The banner shows epoch / total epochs (it used to print epochs + 1).
        print(f"Epoch {epoch}/{args.epochs}\n-------------------------------")
        start = time.perf_counter()
        # start to train
        total_loss, total_dataset = train_loop(
            args, train_dataloader, model, loss_fn, optimizer, epoch, total_loss)
        elapsed = time.perf_counter() - start
        print(f"Epoch {epoch}/{args.epochs} Elapsed time:",
              elapsed, flush=True)
        print(f"Epoch {epoch}/{args.epochs} Processed dataset length:",
              total_dataset, flush=True)
        throughput = 1.0 * total_dataset / elapsed if elapsed > 0 else 0.0
        msg = "Epoch {}/{} Throughput: {: .4f}".format(
            epoch, args.epochs, throughput)
        total_time += elapsed
        total_throughput += total_dataset
        print(msg, flush=True)
        # to valid acc
        test_loop(args, valid_dataloader, model, mode='Valid')

    print("[INFO]Finish all test", flush=True)
    # Both averages are guarded so that a run with zero epochs does not crash.
    avg_time = total_time / args.epochs if args.epochs > 0 else 0.0
    msg = "[INFO]Average training time per epoch: {: .4f}".format(avg_time)
    print(msg, flush=True)

    avg_throughput = total_throughput / total_time if total_time > 0 else 0.0
    msg = "[INFO]Average throughput per epoch: {: .4f}".format(avg_throughput)
    print(msg, flush=True)

    # The caller saves the returned object's state_dict().
    return model

In [None]:
def main(args=None):
    """Entry point: parse options, build data and model, train, optionally save.

    Args:
        args: List of argument strings forwarded to ``parse_args``
            (``None`` lets argparse read ``sys.argv``).
    """
    # parse args
    args = parse_args(args)
    print(args)
    # set log_path
    set_log_path(args)

    # Seed torch before the model and data loaders are created.
    torch.manual_seed(args.seed)

    # init pytorch distributed network if need
    if should_distribute():
        print("Using distributed PyTorch with {} backend".format(
            "GLOO"), flush=True)
        dist.init_process_group(backend=dist.Backend.GLOO)

    # Load train and valid data
    print("[INFO]Before data get loaded", flush=True)
    train_data = Dataset('train', args.dataset)
    print("######train data length:", len(train_data.data), flush=True)
    valid_data = Dataset('validation', 1)

    train_dataloader, valid_dataloader = get_dataloader(args, train_data, valid_data)
    print("[INFO]Data get loaded successfully", flush=True)

    # init model
    model = NeuralNetwork(args).to(args.device)
    # load pre-train model
    if args.load_model:
        # map_location places the checkpoint tensors on the configured device
        # regardless of the device they were saved from.
        model.load_state_dict(torch.load('./pert.bin', map_location=args.device))
    # local or distributed model
    if is_distributed():
        Distributor = nn.parallel.DistributedDataParallel
        model = Distributor(model, find_unused_parameters=True)

    # set loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=args.lr)

    # train epoch
    trained_model = do_train(args, train_dataloader, valid_dataloader, model, loss_fn, optimizer)
    if trained_model is None:
        # do_train may not hand the model back; fall back to the one trained in place.
        trained_model = model

    # save model and exit
    if args.save_model:
        # A DistributedDataParallel wrapper keeps the real network in .module;
        # save that one so the keys match what --load-model loads above.
        to_save = getattr(trained_model, "module", trained_model)
        torch.save(to_save.state_dict(), "pert.bin")
    if is_distributed():
        dist.destroy_process_group()

In [None]:
if __name__ == "__main__":
    # Local (non-k8s) run: build the argument list from a PPMLConf and train.
    import os
    import sys  # needed for sys.exit() below; it is not imported at the top

    # NOTE(review): `datasets` is already imported by the first cell; confirm
    # the library still honours this variable when it is set this late.
    os.environ['HF_DATASETS_OFFLINE'] = '1'
    import ppml_conf
    # "epoch" relies on argparse prefix matching to reach the --epochs option,
    # assuming conf_to_args() emits each key as a "--key value" pair -- TODO confirm.
    local_conf = ppml_conf.PPMLConf(k8s_enabled = False) \
    .set("epoch", "2") \
    .set("log-interval", "20") \
    .set("test-batch-size", "16") \
    .set("batch-size", "16") \
    .set("local-only", "true") \
    .set("dataset", "1") \
    .set("model-path", "/ppml/model")

    args1=local_conf.conf_to_args()

    main(args1)
    sys.exit()

In [None]:
import k8s_deployment
import ppml_conf

# Kubernetes run: pod environment variables and per-pod resources.
k8s_conf = (
    ppml_conf.PPMLConf(k8s_enabled = True, sgx_enabled=False)
    .set_k8s_env("GLOO_TCP_IFACE", "ens803f0")
    .set_k8s_env("HF_DATASETS_OFFLINE", "1")
    .set_k8s("nnodes", "2")
    .set_k8s("pod_cpu", "13")
    .set_k8s("pod_memory", "64G")
    .set_k8s("pod_epc_memory", "68719476736")
)

# Each NFS volume is declared with set_volume_nfs(volume_name, nfs_server, nfs_path)
# and attached to the pod with set_volume_mount(mount_path, volume_name).
(
    k8s_conf
    .set_volume_nfs("source-code", "172.168.0.205","/mnt/sdb/disk1/nfsdata/wangjian/idc")
    .set_volume_mount("/ppml/notebook/nfs", "source-code")
    .set_volume_nfs("nfs-data", "172.168.0.205", "/mnt/sdb/disk1/nfsdata/guancheng/hf")
    .set_volume_mount("/root/.cache", "nfs-data")
    .set_volume_nfs("nfs-model", "172.168.0.205", "/mnt/sdb/disk1/nfsdata/guancheng/model/chinese-pert-base")
    .set_volume_mount("/ppml/model", "nfs-model")
)

# Turn the configuration into an argument list and launch the deployment.
k8s_args = k8s_conf.conf_to_args()
k8s_deployment.run_k8s(k8s_args)

In [None]:
# 1.define model
# 2.define dataset and dataloader

In [None]:
# 1. init ppml_context
#     k8s or local
# 2. init model
# 3. init dataset and dataloader
# 4. 
# import ppml_conf
# myconf = ppml_conf.PPMLConf()
# myconf.test()
# myconf.conf_to_args()