In [1]:
import os
import sys

# Directory added to the import path before the project imports below
# (presumably the LoRA checkout that holds `src` and `loralib` — confirm).
# Set the LORA_ROOT environment variable to override the machine-specific
# default instead of editing this cell.
LORA_ROOT = os.environ.get("LORA_ROOT", "/home/haoqi.whq/llm-inference/LoRA")

# Append only once so re-running this cell does not stack duplicate entries.
if LORA_ROOT not in sys.path:
    sys.path.append(LORA_ROOT)

from src.model import GPT2Config, GPT2LMModel
import torch
from loralib import PruneLayer
import loralib as lora
import math
import importlib

## Load Data

In [2]:
from torch.utils.data import DataLoader
from src.data_utils import FT_Dataset

# --- Data configuration -----------------------------------------------------
# The file locations get their own names so that `train_data` / `valid_data`
# only ever refer to dataset objects, never to path strings.
train_data_path = "./data/e2e/train.jsonl"
valid_data_path = "./data/e2e/valid.jsonl"
train_batch_size = 8
valid_batch_size = 4
seq_len = 512
obj = "clm"  # the training dataset is built with joint_lm=True only when this is "jlm"
# NOTE(review): the two settings below are not consumed by any executed code
# visible in this notebook — confirm before relying on them.
random_seed = 888
label_smooth = 0.1

# Training dataset; joint_lm is derived from `obj`.
train_data = FT_Dataset(
    train_data_path, train_batch_size, seq_len, joint_lm=(obj == "jlm")
)

# Validation dataset; joint_lm is not passed, so FT_Dataset's default applies.
valid_data = FT_Dataset(
    valid_data_path,
    valid_batch_size,
    seq_len,
)

# Both loaders load in the main process (num_workers=0), keep the dataset's
# own ordering (shuffle=False, no sampler) and do not pin memory.
# A DistributedSampler(seed=random_seed) was previously sketched here in
# comments; it is not used.

# Training: the trailing partial batch is dropped.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=train_batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=0,
    pin_memory=False,
)

# Validation: the trailing partial batch is kept.
valid_loader = DataLoader(
    dataset=valid_data,
    batch_size=valid_batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    pin_memory=False,
)

# Sanity check: show the loader type, then the tensor shapes of the first
# training batch only (the loop stops after one iteration).
print(type(train_loader))

for i, data in enumerate(train_loader):
    data = dict(data.items())  # shallow copy of the batch mapping
    _input, _target, _msk = data["input"], data["target"], data["mask"]

    # One shape per line: input, target, mask.
    print(_input.shape, _target.shape, _msk.shape, sep="\n")
    break

<class 'torch.utils.data.dataloader.DataLoader'>
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])


## Define and Load Model

In [3]:
# Model configuration for medium GPT2.
config = GPT2Config(
    # Backbone size.
    n_layer=24,
    n_head=16,
    n_embd=1024,
    # LoRA hyper-parameters.
    lora_attn_dim=4,
    lora_attn_alpha=32,
    lora_dropout=0.1,
    # Per-projection switches (presumably q/k/v/output attention projections
    # and the MLP — confirm against GPT2Config).
    enable_wq=True,
    enable_wk=True,
    enable_wv=True,
    enable_wo=True,
    enable_mlp=True,
)
lm_net = GPT2LMModel(config)

# Example state_dict keys noted by the author for reference:
#   transformer.h.0.attn.q_atten.lora_B
#   lm_head.decoder.weight

# Load weights from two files. Judging by the file names, the first is the
# pretrained GPT-2 medium checkpoint and the second a LoRA checkpoint —
# TODO confirm the argument order against load_lora_weight.
lm_net.load_lora_weight(
    "./pretrained_checkpoints/gpt2-medium-pytorch_model.bin",
    "./tmp/retrain/qkvom/no_cp/model.ncp.lora.26000.pt",
)

O Attention using LoRA: PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
)
MLP using LoRA: PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
), PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
)


In [4]:
# Walk every sub-module and, for each PruneLayer instance, print the module,
# the value of its complexity() call and its lora_scaling attribute.
for m in lm_net.modules():
    if not isinstance(m, PruneLayer):
        continue
    print(m, m.complexity(), m.lora_scaling)

PruneLinear(
  in_features=1024, out_features=1024, bias=True
  (lora_dropout): Dropout(p=0.1, inplace=False)
) tensor(51542.3086, grad_fn=<MulBackward0>) Parameter containing:
tensor(6.2918, requires_grad=True)
PruneLinear(
  in_features=1024, out_features=1024, bias=True
  (lora_dropout): Dropout(p=0.1, inplace=False)
) tensor(52160.0273, grad_fn=<MulBackward0>) Parameter containing:
tensor(6.3672, requires_grad=True)
PruneLinear(
  in_features=1024, out_features=1024, bias=True
  (lora_dropout): Dropout(p=0.1, inplace=False)
) tensor(52289.1562, grad_fn=<MulBackward0>) Parameter containing:
tensor(6.3830, requires_grad=True)
PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
) tensor(54491.7422, grad_fn=<MulBackward0>) Parameter containing:
tensor(6.6518, requires_grad=True)
PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
) tensor(129962.5156, grad_fn=<MulBackward0>) Parameter containing:
tensor(6.3458, requires_grad=True)
PruneGPTConv1D(
  (lora_dropout

In [5]:
# Move the model to the GPU. Module.cuda() moves parameters in place, so the
# return value does not need to be re-assigned.
lm_net.cuda()

# Wrap in DataParallel only once: without this guard, re-running the cell
# would nest a DataParallel inside another DataParallel.
if not isinstance(lm_net, torch.nn.DataParallel):
    lm_net = torch.nn.DataParallel(lm_net)

# NOTE(review): presumably freezes everything except the LoRA parameters
# (going by the helper's name). It is called on the DataParallel wrapper —
# confirm the helper copes with wrapped parameter names.
lora.mark_only_lora_as_trainable(lm_net)

## Inference

In [6]:
class AverageMeter(object):
    """Running-average tracker.

    Attributes kept up to date by ``update``:
      val   -- the most recently recorded value
      sum   -- total of ``val * n`` over all updates
      count -- total of ``n`` over all updates
      avg   -- ``sum / count``

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` observations."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

## Test Prune LoRA

In [None]:
# Calls loralib's prune_lora on lm_net with percent_prune=0.5.
# NOTE(review): 0.5 presumably means "prune half" — confirm the unit
# (fraction vs. percentage) against loralib.
# NOTE(review): this cell carries no execution count ("In [None]") in the
# saved notebook, so any recorded evaluation output may come from a model
# that was never pruned — confirm by re-running from a fresh kernel.
lora.prune_lora(lm_net, percent_prune=0.5)

In [10]:
# Evaluation pass over the validation loader: model in eval mode, gradients
# disabled, and a running mean kept over the per-batch losses.
lm_net.eval()
avg_lm_loss = AverageMeter()

with torch.no_grad():
    for idx, data in enumerate(valid_loader):
        data = dict(data.items())  # shallow copy of the batch mapping

        # Move the three batch tensors to the GPU (input, target, mask order).
        _input, _target, _msk = (
            data[field].cuda() for field in ("input", "target", "mask")
        )

        _lm_logits, _loss = lm_net(_input, lm_labels=_target, lm_mask=_msk)
        # Reduce the returned loss to a single scalar before recording it.
        loss = _loss.mean()
        avg_lm_loss.update(loss.item())

        # Progress line every 100 iterations; idx is the batch index.
        if not idx % 100:
            print("eval samples:", idx, "loss:", loss.float())

# Mean of the batch losses and its exponential (a perplexity if the loss is a
# per-token cross-entropy — TODO confirm).
print(avg_lm_loss.avg, math.exp(avg_lm_loss.avg))



eval samples: 0 loss: tensor(1.8663, device='cuda:0')
eval samples: 100 loss: tensor(1.6873, device='cuda:0')
eval samples: 200 loss: tensor(2.0302, device='cuda:0')
eval samples: 300 loss: tensor(2.0830, device='cuda:0')
eval samples: 400 loss: tensor(2.3319, device='cuda:0')
eval samples: 500 loss: tensor(2.3907, device='cuda:0')
eval samples: 600 loss: tensor(2.0874, device='cuda:0')
eval samples: 700 loss: tensor(1.7890, device='cuda:0')
eval samples: 800 loss: tensor(2.1447, device='cuda:0')
eval samples: 900 loss: tensor(2.5808, device='cuda:0')
eval samples: 1000 loss: tensor(2.3787, device='cuda:0')
eval samples: 1100 loss: tensor(1.9162, device='cuda:0')
2.0083035004057295 7.450666566523994
