In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [1]:
import os

#현재 폴더 경로; 작업 폴더 기준
print(os.getcwd())

/content


In [2]:
os.chdir("/content/drive/MyDrive/chart2text/PALI3")

In [3]:
print(os.getcwd())

/content/drive/MyDrive/chart2text/PALI3


In [6]:
!pip install einops pytorch_model_summary




In [7]:
!pip install zetascale

Collecting argparse<2.0.0,>=1.4.0 (from zetascale)
  Using cached argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: argparse
Successfully installed argparse-1.4.0


In [4]:
from CustomPALI3.SummaryChartDataset import SummaryChartDataset
from CustomPALI3.CustomPALI3 import CustomPALI3Config,CustomPALI3
from transformers import T5Tokenizer
import torch
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
import os
from pytorch_model_summary import summary
import numpy as np
import gc

dataset_repo=[{'dataset':'timm/imagenet-12k-wds','config':'default','type':'vision-text'},
                {'dataset':'wikimedia/wikipedia','config':'20231101.en','type':'text'},
                {'dataset':'conceptual_captions','config':'labeled','type':'vision-text'},
                {'dataset':'poloclub/diffusiondb','config':'2m_random_1m','type':'vision-text'},
                ]

def pretrain(
    model,
    args,
    train_loader,
    val_loader,
    optimizer,
    device,
    scheduler,

):
    model_path=args['output_dir']
    scaler = GradScaler()
    best_val_loss = float("inf")
    for epoch in range(int(args['num_epochs'])):
        model.model.train()
        step_num=0
        for _ in range(int(args['max_steps'])):
            try:
                gc.collect()
                torch.cuda.empty_cache()
                input_data = next(iter(train_loader))
                images=input_data['image']
                input_ids=input_data['input_ids']
                attn_masks=input_data['attn_mask']
                # print(len(input_ids))
                for sub_step in range(len(input_ids)):
                    # print(model.pali_model.parameters())
                    image = images[sub_step].to(device)
                    prompt=input_ids[sub_step].to(device)
                    output=input_ids[sub_step].to(device)
                    attn_mask=attn_masks[sub_step].to(device)
                    print(image.shape,prompt.shape,output.shape,attn_mask.shape)

                    # print(images[sub_step].shape,input_ids[sub_step].shape,attn_masks[sub_step].shape)
                    optimizer.zero_grad()
                    prev_dec_some_weight=model.model.pali_model.decoder.net.attn_layers.layers[0][1].to_out.weight[0,0].item()
                    prev_enc_some_weight=model.model.pali_model.encoder.attn_layers.layers[1][1].ff[0][0].weight[0,0].item()
                    prev_vit_some_weight=model.model.vit_model.model.model.vision_model.encoder.layers[-1].mlp.fc1.weight[0,0].item()
                    with autocast(dtype=torch.bfloat16):
                        logits, loss = model(img=image,prompt=prompt,output=output,mask=attn_mask)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()

                    # logits, loss = model(img=image,prompt=prompt,output=output,mask=attn_mask)
                    # loss.backward()
                    # optimizer.step()
                    # scheduler.step()
                    prev_dec_some_weight2=model.model.pali_model.decoder.net.attn_layers.layers[0][1].to_out.weight[0,0].item()
                    prev_enc_some_weight2=model.model.pali_model.encoder.attn_layers.layers[1][1].ff[0][0].weight[0,0].item()
                    prev_vit_some_weight2=model.model.vit_model.model.model.vision_model.encoder.layers[-1].mlp.fc1.weight[0,0].item()
                    print('diff_vit_enc:',prev_vit_some_weight2-prev_vit_some_weight,'diff_text_enc:',prev_enc_some_weight2-prev_enc_some_weight,'diff_text_dec',prev_dec_some_weight2-prev_dec_some_weight)
                    print(f"Epoch: {epoch+1}, Step: {step_num+1}, Train Loss: {loss}")
                    step_num+=1
                scheduler.step()
                if step_num%100==0 and step_num!=0:
                    save_checkpoint(model,model_path+'_temp')
            except Exception as e:
                print('occurs error : ',e)
                continue

        val_loss = validate(model, val_loader, device,args)

        print(f"Epoch: {epoch+1}, Train Loss: {loss}, Val Loss: {val_loss}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model,model_path)

def validate(model, dataloader, device,args):
    model.model.eval()
    total_loss = 0
    with torch.no_grad():
        for _ in range(int(args['valid_steps'])):
            try:
                input_data = next(iter(dataloader))
                images=input_data['image']
                input_ids=input_data['input_ids']
                attn_masks=input_data['attn_mask']
                for sub_step in range(len(input_ids)):
                    image = images[sub_step].to(device)
                    prompt=input_ids[sub_step].to(device)
                    output=input_ids[sub_step].to(device)
                    attn_mask=attn_masks[sub_step].to(device)
                    logits, loss = model(img=image,prompt=prompt,output=output,mask=attn_mask)
                    total_loss += loss
            except:
                continue
    return total_loss / len(dataloader)

def save_checkpoint(model,save_path):
    model.save_pretrained(save_path, from_pt=True)

def my_collate_fn(samples):
    image_batch = []
    input_batch = []
    attn_mask_batch = []

    batch_size=len(samples)

    image_batch_ = []
    input_batch_ = []
    attn_mask_batch_ = []
    for sample in samples:
        image_batch_.extend(sample['image'])
        input_batch_.extend(sample['input_ids'])
        attn_mask_batch_.extend(sample['attn_mask'])

    total_b=len(image_batch_)//batch_size # 14 //4   3.xx 3
    total_b=total_b+1 if len(image_batch_)%batch_size!=0  else total_b
    for i in range(total_b):
        if (i+1)*batch_size<len(image_batch_):
            image_batch.append(torch.stack(image_batch_[i*batch_size:(i+1)*batch_size]))
            input_batch.append(torch.stack(input_batch_[i*batch_size:(i+1)*batch_size]))
            attn_mask_batch.append(torch.stack(attn_mask_batch_[i*batch_size:(i+1)*batch_size]))
        else:
            image_batch.append(torch.stack(image_batch_[i*batch_size:]))
            input_batch.append(torch.stack(input_batch_[i*batch_size:]))
            attn_mask_batch.append(torch.stack(attn_mask_batch_[i*batch_size:]))

    return {'image': image_batch, 'input_ids': input_batch,'attn_mask':attn_mask_batch}
def getParameters(model):
    count=0
    def mul(list_):
            init=1
            for i in list_:
                    init*=i
            return init
    for name, param in model.vit_model.named_parameters():
            count+=mul(np.array(param.size()).tolist())
    print('vit_model',count/1000000,"M")
    count=0
    for name, param in model.pali_model.named_parameters():
            count+=mul(np.array(param.size()).tolist())
    print('pali_model',count/1000000,"M")


  _torch_pytree._register_pytree_node(
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [None]:

if __name__=='__main__':
    args={
        'output_dir':'/content/drive/MyDrive/chart2text/PALI3/output',
        'lr':1e-4,
        'max_steps':1e4,
        'valid_steps':1e2,
        'num_epochs':100,
        'batch_size':4,
        'num_training_samples_per_epoch':10,
        'max_epochs':100,
        "warmup_steps":100,
        'num_workers':1,
        'num_nodes':1,
        }

    tokenizer=T5Tokenizer.from_pretrained("google/flan-t5-base", bos_token = '<s>',add_bos_token = True)
    train_loader=DataLoader(SummaryChartDataset(dataset_repo,1024,tokenizer,'</s>','train'), batch_size=args['batch_size'], shuffle=True, num_workers=1,collate_fn=my_collate_fn)
    val_loader=DataLoader(SummaryChartDataset(dataset_repo,1024,tokenizer,'</s>','validation'), batch_size=args['batch_size'], shuffle=True, num_workers=1,collate_fn=my_collate_fn)
    config=CustomPALI3Config(version=1,model_name='test',
                        dim=1024,enc_num_tokens=32100,enc_max_seq_len=1024,
                        dec_num_tokens=32100,dec_max_seq_len=1024,enc_depth=12,enc_heads=18,dec_depth=12,dec_heads=18,seq_len=1024
                        ,device='cuda',vit_fix=False)

    # device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model=CustomPALI3(config)
    print(device)
    # model=model.from_pretrained("/Users/dongunyun/study/datascience/chart2text/PALI3/output_temp")
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
    scheduler = StepLR(optimizer, step_size=500, gamma=0.1)

    summary(model,torch.zeros((1,3,336,336)).to(device=device,dtype=torch.long),
                            torch.zeros((1,1024)).to(device=device,dtype=torch.long),
                            torch.zeros((1,1024)).to(device=device,dtype=torch.long),
                            torch.ones(1, 1024).bool().to(device=device),
                            show_input=True, print_summary=True,)
    getParameters(model.model)
    pretrain(
        model,
        args,
        train_loader,
        val_loader,
        optimizer,
        device,
        scheduler,
    )

cuda
---------------------------------------------------------------------------------------------------------
      Layer (type)                                           Input Shape         Param #     Tr. Param #
    CustomPALI3_-1     [1, 3, 336, 336], [1, 1024], [1, 1024], [1, 1024]     809,645,768     809,645,768
Total params: 809,645,768
Trainable params: 809,645,768
Non-trainable params: 0
---------------------------------------------------------------------------------------------------------
vit_model 304.557056 M
pali_model 505.088712 M
torch.Size([4, 3, 336, 336]) torch.Size([4, 1024]) torch.Size([4, 1024]) torch.Size([4, 1024])
occurs error :  CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

occurs error :  CUDA error: device-side assert tri

Exception ignored in: <function _xla_gc_callback at 0x7c3095898f70>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 97, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 


occurs error :  CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

occurs error :  CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

occurs error :  CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

occurs error :  CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronous