In [None]:
# Install the notebook's dependencies into the kernel's environment.
# NOTE(review): no versions are pinned, so a fresh run may resolve different releases
# than the ones this notebook was developed against.
%pip install -q transformers
# Provides dalle_tiny.model.TinyDalleModel and dalle_tiny.util.TinyDalleDataset used below.
%pip install -q git+https://github.com/cthiounn/dalle-tiny.git
%pip install -q wandb
%pip install -q hivemind
# NOTE(review): this runs after the installs above and may replace whatever torch version they
# resolved; confirm the upgraded build matches the machine's CUDA driver.
%pip install --upgrade torch

In [None]:
import os

# Credentials and run configuration come from the environment so they are never saved
# into the notebook file. Each falls back to "" when the variable is unset.
WANDB_KEY = os.environ.get("WANDB_KEY", "")
# S3 location from which config.json / pytorch_model.bin are downloaded (and checkpoints uploaded).
S3_BUCKET = os.environ.get("S3_BUCKET", "")
# Tag inserted into checkpoint file names.
CUSTOM_SAVE_FILE_NAME = os.environ.get("CUSTOM_SAVE_FILE_NAME", "")

In [None]:
# !wandb login $WANDB_KEY

In [None]:
# import wandb

# # wandb.init(project="my-test-project", entity="cthiounn",id="1piyh3bl",resume="must")
# wandb.init(project="my-test-project", entity="cthiounn")
# wandb.config = {
#   "learning_rate": 5e-5,
#   "epochs": 200,
#   "batch_size": 5
# }

In [None]:
from tqdm import tqdm
import s3fs
import os

# Build an s3fs client for the S3-compatible endpoint named by $AWS_S3_ENDPOINT
# (a KeyError is raised if the variable is not set).
S3_ENDPOINT_URL = f"https://{os.environ['AWS_S3_ENDPOINT']}"
fs = s3fs.S3FileSystem(client_kwargs=dict(endpoint_url=S3_ENDPOINT_URL))

def write_file_to_s3(bucket_name: str, dir_file: str, file_name: str, fs: "s3fs.core.S3FileSystem") -> bool:
    """Upload a local file to ``<bucket_name>/<file_name>`` on the given filesystem.

    Args:
        bucket_name: destination bucket; the object key is ``file_name``.
        dir_file: local directory containing ``file_name``. It is joined with
            ``os.path.join``, so a trailing separator is optional.
        file_name: file name, used both for the local path and the remote key.
        fs: filesystem object exposing ``open(path, mode)`` (an s3fs filesystem here).

    Returns:
        True if the file was uploaded. False if the upload was skipped because
        ``bucket_name``, ``file_name`` or ``fs`` is empty/falsy (best-effort: no error).

    Raises:
        FileNotFoundError: if the local file does not exist. The remote file is not
            opened in that case, so no empty object is left behind.
    """
    import shutil

    if not (bucket_name and file_name and fs):
        return False
    remote_path = bucket_name + "/" + file_name
    local_path = os.path.join(dir_file, file_name)
    # Open the local file first so a missing source fails before a remote file is created.
    with open(local_path, 'rb') as file_in, fs.open(remote_path, 'wb') as file_out:
        # Stream in chunks rather than read()-ing a possibly multi-GB checkpoint into memory.
        shutil.copyfileobj(file_in, file_out)
    return True
            
import shutil

# Fetch the checkpoint (model config + weights) from S3 into the working directory,
# which is where the model cell's `from_pretrained('.')` looks for it.
files = ['config.json', 'pytorch_model.bin']
for file in tqdm(files):
    with fs.open(f'{S3_BUCKET}/{file}', mode="rb") as file_in, open(file, "wb") as file_out:
        # Stream in chunks: weight files can be large, so avoid reading them fully into memory.
        shutil.copyfileobj(file_in, file_out)


In [None]:
from dalle_tiny.model import TinyDalleModel
from dalle_tiny.util import TinyDalleDataset
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

SEED = 42
# Per-step batch size; keep in sync with batch_size_per_step passed to hivemind.Optimizer.
BATCH_SIZE = 5
TRAIN_PARQUET_URL = "https://github.com/cthiounn/dalle-tiny/raw/main/archive_train.parquet"
VAL_PARQUET_URL = "https://github.com/cthiounn/dalle-tiny/raw/main/archive_val.parquet"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.manual_seed seeds the CPU generator and every CUDA device, so a separate
# torch.cuda.manual_seed_all call is unnecessary.
torch.manual_seed(SEED)

training_data = TinyDalleDataset(parquet_file=TRAIN_PARQUET_URL, dataset_type="train")
test_data = TinyDalleDataset(parquet_file=VAL_PARQUET_URL, dataset_type="val")

# shuffle=True draws from torch's global RNG, seeded above.
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)


In [None]:
from tqdm import tqdm
from transformers import BartForConditionalGeneration
import gc

import torch.nn as nn

# NOTE(review): benchmark mode autotunes cuDNN kernels per input shape; its benefit for this
# model is unverified.
torch.backends.cudnn.benchmark = True

# def freeze_params(model):
#     for par in model.parameters():
#         par.requires_grad = False

# Collect garbage and release unreferenced cached CUDA memory before building the model
# (presumably leftovers from an earlier run of this cell).
gc.collect()
torch.cuda.empty_cache()

# Resume from the checkpoint downloaded into the working directory; if it is missing or cannot
# be loaded, start from the pretrained BART weights instead.
try:
    model = TinyDalleModel.from_pretrained('.')
except Exception as err:  # not a bare except: KeyboardInterrupt/SystemExit must still propagate
    print(f"Could not load local checkpoint ({err!r}); falling back to facebook/bart-large-cnn")
    model = TinyDalleModel.from_pretrained('facebook/bart-large-cnn')

# NOTE(review): this also runs after a successful resume from '.'; confirm it does not
# re-initialize (and thus discard) the weights that were just loaded.
model.reinit_model_for_images()
model = model.to(device)
# freeze_params(model.get_encoder())
model.train()
optimizer = optim.AdamW(model.parameters(), betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, lr=5e-5)

loss_function = nn.CrossEntropyLoss()


def loss_fn(logits, labels):
    """Token-level cross-entropy between decoder logits and target ids.

    Args:
        logits: tensor whose last dimension indexes the vocabulary, e.g.
            (batch, seq, vocab) as produced by the model's decoder.
        labels: integer class ids whose shape equals ``logits.shape[:-1]``,
            e.g. (batch, seq).

    Returns:
        Scalar cross-entropy averaged over all positions (``nn.CrossEntropyLoss``
        default reduction).
    """
    # Flatten every leading dimension so CrossEntropyLoss sees (N, vocab) vs (N,). Unlike
    # hard-coding three dimensions, this also accepts already-flattened (N, vocab) inputs.
    vocab_size = logits.shape[-1]
    return loss_function(logits.reshape(-1, vocab_size), labels.reshape(-1))

In [None]:
import hivemind
import os

# Optional multiaddrs of an existing swarm (comma- or whitespace-separated), e.g. the list
# printed by another peer below. When unset, this peer starts a new swarm.
INITIAL_PEERS = os.environ.get("HIVEMIND_INITIAL_PEERS", "").replace(",", " ").split()
if INITIAL_PEERS:
    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
else:
    dht = hivemind.DHT(start=True)
print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])

In [None]:
# Set up a decentralized optimizer that will average with peers in background
opt = hivemind.Optimizer(
    dht=dht,                  # use a DHT that is connected with other peers
    run_id='tinydalle_run',   # unique identifier of this collaborative run
    # each call to opt.step adds this many samples towards the next epoch; read from the
    # training DataLoader so the two values cannot drift apart
    batch_size_per_step=train_dataloader.batch_size,
    target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch
    optimizer=optimizer,      # wrap the local AdamW optimizer created with the model
    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
    matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
    averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
    verbose=True,             # print progress logs
)

In [None]:
num_batches_test = len(test_dataloader)

# wandb.watch(model)

i = 0
# Collaborative training has no stopping condition: the loop runs until the kernel is interrupted.
while True:
    for batch in tqdm(train_dataloader):
        opt.zero_grad()
        caption, label = batch
        inp = caption.to(device)
        # NOTE(review): the label batch is indexed with [0]; presumably the dataset yields the
        # target ids wrapped in a sequence -- confirm against TinyDalleDataset.
        lab = label[0].to(device)
        # Decoder inputs are derived from the targets by the model's helper (shifted labels).
        shifted_label = model.prepare_decoder_input_ids_from_labels(lab).to(device)
        predict = model(input_ids=inp, decoder_input_ids=shifted_label)
        loss = loss_fn(predict.logits, lab)
        loss.backward()
        # Report the real number of samples: the last DataLoader batch of each pass can be smaller
        # than batch_size_per_step, which would otherwise over-count progress toward target_batch_size.
        opt.step(batch_size=inp.shape[0])
        # NOTE(review): the disabled evaluation below references an undefined `epoch` and never
        # switches the model to eval() mode; fix both before re-enabling it.
#         i+=1
#         if i%100==0:
#             wandb.log({"train_loss": loss.item()})
#             print(f"train_loss:{loss.item()}")
#             test_loss=0
#             with torch.no_grad():
#                 for batch in tqdm(test_dataloader): 
#                     caption,label =batch
#                     inp=caption.to(device)
#                     lab=label[0].to(device)
#                     shifted_label=model.prepare_decoder_input_ids_from_labels(lab).to(device)
#                     predict=model(input_ids=inp, decoder_input_ids =shifted_label)
#                     loss = loss_fn(predict.logits, lab)
#                     test_loss += loss.item()
#                     del inp, lab, predict, loss, shifted_label
#                     torch.cuda.empty_cache()

#                 mean_test_loss=test_loss/num_batches_test
#                 wandb.log({"mean test_loss": mean_test_loss})
#                 print(f"mean test loss:{mean_test_loss}")

#                 file_name=f"../../checkpoint_{CUSTOM_SAVE_FILE_NAME}_{epoch}.pth"
#                 torch.save(model.state_dict(),file_name)
#                 try:
#                     write_file_to_s3(S3_BUCKET,"../../",f"checkpoint_{CUSTOM_SAVE_FILE_NAME}_{epoch}.pth",fs)
#                 except:
#                     print(f"can't write {file_name}")

        # Drop references so this step's logits are freed before the next forward pass.
        # torch.cuda.empty_cache() is deliberately not called per step: it does not lower PyTorch's
        # own memory use and forces the caching allocator to re-allocate on every iteration.
        del inp, lab, predict, loss, shifted_label

        


    