In [None]:
# Install the notebook's runtime dependencies into the kernel's environment.
# NOTE(review): none of these installs is version-pinned, so a re-run may pick up
# different releases than the ones this notebook was developed against.
%pip install -q transformers
# dalle-tiny is installed directly from its GitHub repository.
%pip install -q git+https://github.com/cthiounn/dalle-tiny.git
%pip install -q wandb
# Upgrades torch to the newest release available to pip.
%pip install --upgrade torch

In [None]:
import os
from datetime import date
from tqdm import tqdm
import gc
import s3fs
from transformers import BartForConditionalGeneration
from dalle_tiny.model import TinyDalleModel
from dalle_tiny.util import TinyDalleDataset
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch.utils.data as data_utils
from torch.utils.data import DataLoader
import wandb

# Default parameters derived from the runtime environment.
today = date.today()
# Date stamp (YYYY_MM_DD) used in the default checkpoint file name below.
d1 = today.strftime("%Y_%m_%d")
# AWS_S3_ENDPOINT must be set in the environment; a missing variable raises KeyError here.
S3_ENDPOINT_URL = "https://" + os.environ["AWS_S3_ENDPOINT"]

# Parameters to edit
LOAD_FILE_FROM_S3=False
# The W&B API key is taken from the WANDB_API_KEY environment variable so that the secret
# is never typed into (and saved with) the notebook. Empty when the variable is unset.
WANDB_KEY=os.environ.get("WANDB_API_KEY", "")
WANDB_PROJECT_NAME="dalle-tiny"
# Non-empty value = id of an existing W&B run to resume; empty = start a new run.
WANDB_RUN_NAME=""
S3_BUCKET="cthiounn2"
CUSTOM_SAVE_FILE_NAME="test_"+d1
DEFAULT_SEED=42
# True selects the "paysage" parquet files instead of the full train2017/val2017 ones.
PAYSAGE_DATASET=False

# Optional extra checkpoint to fetch from S3 (set to "" to skip it).
ADDITIONAL_SAVEFILE_FROM_S3="checkpoint_paysages_bis2022_04_20_26000.pth"
# Files downloaded from the bucket when LOAD_FILE_FROM_S3 is True.
LIST_FILE_FROM_S3=['config.json','pytorch_model.bin']
if PAYSAGE_DATASET:
    LIST_FILE_FROM_S3.extend(['paysage_train2017_caption_image.parquet','paysage_val2017_caption_image.parquet'])
else:
    LIST_FILE_FROM_S3.extend(['train2017_caption_image.parquet','val2017_caption_image.parquet'])
if ADDITIONAL_SAVEFILE_FROM_S3:
    LIST_FILE_FROM_S3.append(ADDITIONAL_SAVEFILE_FROM_S3)

# Hyper-parameters; also passed to wandb.init as the run configuration.
config = {
  "learning_rate": 4e-5,
  "epochs": 200,
  "batch_size": 10
}

# Default objects shared by the rest of the notebook.
fs = s3fs.S3FileSystem(client_kwargs={'endpoint_url': S3_ENDPOINT_URL})
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed both the CUDA and the CPU generators for reproducibility.
torch.cuda.manual_seed_all(DEFAULT_SEED)
torch.manual_seed(DEFAULT_SEED)



In [None]:
# Authenticate with Weights & Biases through the Python API instead of a shell command,
# so the API key is never interpolated into a command line. An empty WANDB_KEY is passed
# as None, which leaves credential discovery to wandb itself.
wandb.login(key=WANDB_KEY or None)

In [None]:
# Resume the run identified by WANDB_RUN_NAME when one is given, otherwise start a new run
# that records the hyper-parameters in `config`. The project name comes from the
# WANDB_PROJECT_NAME setting rather than a second hard-coded copy of it.
if WANDB_RUN_NAME:
    wandb.init(project=WANDB_PROJECT_NAME, entity="cthiounn",id=WANDB_RUN_NAME,resume="must")
else:
    wandb.init(project=WANDB_PROJECT_NAME, entity="cthiounn",config=config)

In [None]:
def write_file_to_s3(bucket_name:str,dir_file:str,file_name:str,fs:s3fs.core.S3FileSystem):
    """Upload a local file to ``<bucket_name>/<file_name>`` through ``fs``.

    Args:
        bucket_name: Destination bucket; prefixes the remote path.
        dir_file: Local directory prefix. It is concatenated with ``file_name`` as-is,
            so it must already end with a path separator.
        file_name: File name, used both locally and as the remote object name.
        fs: File system object whose ``open`` is used for the remote write.

    The call does nothing when ``bucket_name``, ``file_name`` or ``fs`` is falsy.
    Errors raised while opening or copying are not caught here.
    """
    import shutil  # local import keeps this block self-contained

    if not (bucket_name and file_name and fs):
        return
    remote_path = bucket_name + "/" + file_name
    with fs.open(remote_path, 'wb') as file_out , open(dir_file+file_name, 'rb') as file_in:
        # Stream in chunks instead of read()-ing the whole file, so a large checkpoint
        # is never held entirely in memory.
        shutil.copyfileobj(file_in, file_out)

def load_files_from_s3(load,listfiles):
    """Download every file named in ``listfiles`` from ``S3_BUCKET`` into the working directory.

    Args:
        load: When falsy, nothing is downloaded.
        listfiles: Iterable of object names; each is read from ``S3_BUCKET/<name>`` via the
            module-level ``fs`` and written to a local file of the same name.
    """
    import shutil  # local import keeps this block self-contained

    if not load:
        return
    for file in tqdm(listfiles):
        with fs.open(f'{S3_BUCKET}/{file}', mode="rb") as file_in, open(file,"wb") as file_out:
            # Copy chunk by chunk so large parquet/checkpoint files are not loaded
            # into memory in one piece.
            shutil.copyfileobj(file_in, file_out)

load_files_from_s3(LOAD_FILE_FROM_S3, LIST_FILE_FROM_S3)

In [None]:
# Pick the parquet files and the number of leading samples to keep for the chosen dataset.
if PAYSAGE_DATASET:
    train_parquet = "paysage_train2017_caption_image.parquet"
    val_parquet = "paysage_val2017_caption_image.parquet"
    n_train, n_val = 12_000, 480
else:
    train_parquet = "train2017_caption_image.parquet"
    val_parquet = "val2017_caption_image.parquet"
    n_train, n_val = 118_000, 25_000

training_data = TinyDalleDataset(parquet_file=train_parquet,dataset_type="train")
test_data = TinyDalleDataset(parquet_file=val_parquet,dataset_type="val")
indices_tr = torch.arange(n_train)
indices_va = torch.arange(n_val)

# Restrict each dataset to its first n samples, then batch it.
subtrain_data = data_utils.Subset(training_data, indices_tr)
subval_data = data_utils.Subset(test_data, indices_va)
# Only the training loader shuffles.
train_dataloader = DataLoader(
    subtrain_data,
    batch_size=config["batch_size"],
    shuffle=True,
    pin_memory=True,
)
test_dataloader = DataLoader(
    subval_data,
    batch_size=config["batch_size"],
    pin_memory=True,
)

In [None]:
def run_eval(model, test_dataloader, num_batches_test, wandb):
    """Evaluate ``model`` on ``test_dataloader`` and report the mean loss.

    Args:
        model: Model called with ``input_ids`` / ``decoder_input_ids``; its output must
            expose ``.logits``.
        test_dataloader: Yields ``(caption, label)`` pairs; element ``[0]`` of each is the
            tensor that is moved to ``device``.
        num_batches_test: Divisor used to average the summed batch losses.
        wandb: Object whose ``log`` method receives the ``"mean test_loss"`` metric.
    """
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for captions, labels in tqdm(test_dataloader):
            encoder_ids = captions[0].to(device)
            targets = labels[0].to(device)
            # Decoder input: the targets shifted right by one position, with 16384 in the
            # first column (presumably the start-of-sequence id; TODO confirm).
            decoder_ids = targets.new_zeros(targets.size())
            decoder_ids[:, 0] = 16384
            decoder_ids[:, 1:] = targets[:, :-1]
            decoder_ids = decoder_ids.to(device)
            output = model(input_ids=encoder_ids, decoder_input_ids=decoder_ids)
            total_loss += loss_fn(output.logits, targets).item()
            # Drop the batch tensors before asking CUDA to release cached memory.
            del encoder_ids, targets, decoder_ids, output
            torch.cuda.empty_cache()

    mean_test_loss = total_loss / num_batches_test
    wandb.log({"mean test_loss": mean_test_loss})
    print(f"mean test loss:{mean_test_loss}")

def save_model(CUSTOM_SAVE_FILE_NAME,i,model):
    """Save the model weights to ``../../checkpoint_<name>_<i>.pth`` and upload them to S3.

    Args:
        CUSTOM_SAVE_FILE_NAME: Label embedded in the checkpoint file name.
        i: Step counter embedded in the checkpoint file name.
        model: Model to save; an ``nn.DataParallel`` wrapper is unwrapped first so the
            saved keys carry no ``module.`` prefix.

    The upload is best-effort: a failure is reported on stdout and does not interrupt
    training. Errors from ``torch.save`` itself are not caught.
    """
    base_name = f"checkpoint_{CUSTOM_SAVE_FILE_NAME}_{i}.pth"
    file_name = "../../" + base_name
    # Decide from the model object itself, not from the GPU count: an unwrapped model on a
    # multi-GPU host has no `.module` attribute.
    unwrapped = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(unwrapped.state_dict(), file_name)
    try:
        write_file_to_s3(S3_BUCKET,"../../",base_name,fs)
    except Exception as exc:
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit propagate;
        # the cause is printed so a failed upload can be diagnosed.
        print(f"can't write {file_name}: {exc!r}")

# Cross-entropy criterion shared by the training and evaluation loops.
loss_function = nn.CrossEntropyLoss()

def loss_fn(logits, labels):
    """Cross-entropy between 3-D ``logits`` and 2-D ``labels``.

    The first two dimensions of ``logits`` are flattened into one so that every position
    becomes one row for ``loss_function``; ``labels`` is flattened to the same length.
    """
    n_positions = logits.size(0) * logits.size(1)
    n_classes = logits.size(2)
    flat_logits = logits.reshape(n_positions, n_classes)
    flat_labels = labels.reshape(n_positions)
    return loss_function(flat_logits, flat_labels)

def run_train(CUSTOM_SAVE_FILE_NAME, i, model, train_dataloader, optimizer, wandb):
    """Run one pass over ``train_dataloader`` and return the updated step counter.

    Args:
        CUSTOM_SAVE_FILE_NAME: Label forwarded to ``save_model`` for checkpoint names.
        i: Step counter on entry; incremented once per batch.
        model: Model called with ``input_ids`` / ``decoder_input_ids``; its output must
            expose ``.logits``.
        train_dataloader: Yields ``(caption, label)`` pairs; element ``[0]`` of each is the
            tensor that is moved to ``device``.
        optimizer: Optimizer stepped and zeroed after every batch.
        wandb: Object whose ``log`` method receives ``"train_loss"`` every 100 steps.

    Returns:
        The step counter after the last batch.
    """
    model.train()
    for captions, labels in tqdm(train_dataloader):
        encoder_ids = captions[0].to(device)
        targets = labels[0].to(device)
        # Decoder input: the targets shifted right by one position, with 16384 in the
        # first column (presumably the start-of-sequence id; TODO confirm).
        decoder_ids = targets.new_zeros(targets.size())
        decoder_ids[:, 0] = 16384
        decoder_ids[:, 1:] = targets[:, :-1]
        decoder_ids = decoder_ids.to(device)
        output = model(input_ids=encoder_ids, decoder_input_ids=decoder_ids)
        loss = loss_fn(output.logits, targets)

        i += 1
        if i % 100 == 0:
            step_loss = loss.item()
            wandb.log({"train_loss": step_loss})
            print(f"train_loss:{step_loss}")

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Drop the batch tensors before asking CUDA to release cached memory.
        del encoder_ids, targets, output, loss, decoder_ids
        torch.cuda.empty_cache()

        # Checkpoint every 1000 steps.
        if i % 1000 == 0:
            save_model(CUSTOM_SAVE_FILE_NAME, i, model)
    return i
            
def load_model(filename=None):
    """Build a ``TinyDalleModel`` and optionally load weights from a checkpoint file.

    The model is first created from the files in the current directory; if that fails it
    is created from ``facebook/bart-large-cnn`` instead. ``reinit_model_for_images`` is
    then called, and finally the state dict in ``filename`` (if given) is loaded.

    Args:
        filename: Path of a checkpoint to load, or a falsy value to skip loading.

    Returns:
        The model. A checkpoint that cannot be loaded is reported on stdout and the
        model is returned without those weights (best-effort, as before).
    """
    try :
        model = TinyDalleModel.from_pretrained('.')
    except Exception as exc:
        # Narrowed from a bare except so Ctrl-C is not mistaken for "no local model".
        print(f"local model not usable ({exc!r}); falling back to facebook/bart-large-cnn")
        model = TinyDalleModel.from_pretrained('facebook/bart-large-cnn')

    model.reinit_model_for_images()
    if filename:
        try :
            # NOTE(review): torch.load unpickles the file; only load checkpoints that come
            # from a trusted source.
            model.load_state_dict(torch.load(filename,map_location=device))
        except Exception as exc:
            # Report the cause: silently continuing would train from re-initialised weights.
            print(f"cant load model file {filename}: {exc!r}")
    return model

In [None]:
# Turn on cuDNN's benchmark mode.
torch.backends.cudnn.benchmark = True



# del model
# Release unreferenced Python objects and cached CUDA memory before building the model.
gc.collect()
torch.cuda.empty_cache()


# model=load_model()
# Use the checkpoint named in the configuration instead of a second hard-coded copy of
# the file name; an empty setting yields a model without checkpoint weights.
model=load_model(ADDITIONAL_SAVEFILE_FROM_S3)

def freeze_params(model):
    """Turn off gradient tracking for every parameter of ``model`` (modified in place)."""
    for param in model.parameters():
        param.requires_grad_(False)

# freeze_params(model.get_encoder())
# freeze_params(model.get_decoder())
# model.lm_head=nn.Linear(in_features=1024, out_features=16384+1, bias=False)
# model.final_logits_bias=torch.rand(16384+1)    

# Replicate the model across GPUs when more than one is visible.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

model=model.to(device)
optimizer = optim.AdamW(model.parameters(), betas=(0.9, 0.99), eps=1e-8, weight_decay=0.01, lr=config["learning_rate"])
# Number of evaluation batches; run_eval divides the summed loss by it.
num_batches_test = len(test_dataloader)
wandb.watch(model)

# i=0
# Step counter passed through run_train; 26000 matches the suffix of the checkpoint file
# name used in this notebook (presumably to continue its step numbering; TODO confirm).
i=26000
# The number of epochs comes from the run configuration instead of a separate literal.
for epoch in range(1, config["epochs"] + 1):
    print(f"epoch={epoch}")
    i = run_train(CUSTOM_SAVE_FILE_NAME, i, model, train_dataloader, optimizer, wandb)
    run_eval(model, test_dataloader, num_batches_test, wandb)