In [1]:
import os
import argparse
import glob
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
    BackwardPrefetch,
)
from torch import  nn, zeros, float32, float16, cuda, set_float32_matmul_precision, load, argmax, save, tile
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import StepLR
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import RandomSampler,DataLoader

from tqdm import tqdm
from src.flamcon import Flamcon, LayerNorm
from src.distributed import init_distributed_device, world_info_from_env
from src.misc import save_checkpoint
from src.dataloader import WebVidDataset, RandomVideos


from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)

from einops import rearrange
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam

from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml.logger import Text


set_float32_matmul_precision('medium')

def tokenize(tokenizer,text):
    tokenizer.padding_side = "right"
    text =  tokenizer(
        text,
        max_length=512,
        padding="longest",
        truncation="only_first",
        return_tensors="pt",
    )   
    return  text['input_ids'], text['attention_mask'].bool()

def getLoss(predicted, labels, logits, tokenizer):
    """
    Compute cross-entropy loss between predicted and actual tokens.

    Args:
        predicted (Tensor): Predicted tokens.
        actual (Tensor): Actual tokens.
        logits (Callable): Function to compute logits.

    Returns:
        loss (Tensor): Cross-entropy loss.
    """
    predicted = logits(predicted)[:,-labels.shape[1]:,:]
    #predicted = rearrange(predicted, 'b n c -> b c n')
    print(predicted.shape)
    
    labels[labels == tokenizer.pad_token_id] = -100
    labels[labels == tokenizer.eos_token] = -100
    labels[labels == tokenizer.encode("<media>")[-1]] = -100
        
    loss_fct = nn.CrossEntropyLoss()
    loss = 0
    for i in range(len(predicted)):
        losses = loss_fct(
            predicted[i], labels[i]
        )
        loss+=losses
    loss=loss/len(predicted)
    return loss

def generate(
        model,
        text,
        tokenizer,
        to_logits,
        *,
        images=None,
        videos=None,
        embeds=None,
        gen=True,
        attention_mask=None,
        n_tokens = 100,
        n_samples = 1
        
    ):
        sampler = NucleusSampler(0.95, TemperatureSampler(1.))
        data = torch.tile(text[None, :], (1, 1))[0]
        logs = [[(text, Text.meta)] for _ in range(n_samples)]
        seq_len = len(text)
        for i in range(n_tokens):
            data = data[-seq_len:]
            output = model.forward(data,images=images,videos=videos,embeds=embeds,gen=gen,attention_mask=attention_mask)
            logits = to_logits(output)
            logits = logits[:, -1]
            res = sampler(logits)
            data = torch.cat([data, res.reshape(res.shape[0],1)], dim=1)
#             for j in range(1):
#                 logs[j] += [('' + tokenizer.decode(res.item()), Text.value)]
        return data
    
def test(args, model, rank, dialogue, media, tokenizer, to_logits):
    """
    Perform the training loop for one epoch.

    Args:
        args: Parsed command-line arguments.
        model: The Flamcon model.
        rank (int): Process rank.
        dialog: tokens
        media: image/video

    Returns:
        text_tokens: the output the prediction
    """
    input_ids, attention_mask = tokenize(tokenizer,dialogue)
    input_ids = input_ids.to(rank)
    media = media.to(rank)
    if args.video:
        text_tokens = generate(model,input_ids, tokenizer, to_logits, videos=media)
    else:
        text_tokens = model.generate(input_ids, tokenizer, to_logits, images=media)
    return text_tokens

[2023-09-16 13:57:35,887] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
class Args():
    horovod=False
    dist_backend="nccl"
    dist_url="env://"
    no_set_device_rank=False
    cpu_offload=True
    batch=2
    dim=4544
    num_tokena=65027
    epochs=1
    fsdp=True
    video=True
    max_frames=40
    max_tokens=512
    lang_model="tiiuae/falcon-7b"
    run_name="flamcon"
    my_group=None
    delete_previous_checkpoint=True
    resume=True
args=Args()
args.local_rank, args.rank, args.world_size = world_info_from_env()
device_id = 0

In [3]:
    vit = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = args.dim,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )
    vit = Extractor(vit, return_embeddings_only = True).to('cuda')

In [4]:
    if args.rank == 0:
        print("Loading Falcon\n")
        
    #Loads language model
    falcon = AutoModelForCausalLM.from_pretrained(
                args.lang_model,
                trust_remote_code=True
            )
    tokenizer = AutoTokenizer.from_pretrained(args.lang_model)
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|endofchunk|>", "<media>"]}
    )
    tokenizer.add_special_tokens({'pad_token': '<PAD>'})
    falcon.resize_token_embeddings(new_num_tokens=len(tokenizer))
    
    #set Mixes precision policy
    mp_policy = MixedPrecision(
        param_dtype=float32,
        reduce_dtype=float16,  # gradient communication
        buffer_dtype=float16,
    )

    if args.rank == 0:
        print("Loading Flamcon")
    
    #Print parameters per GPU
    print(f"ViT parameter num: {sum(p.numel() for p in vit.parameters())} on rank {args.rank}\n")
    print(f"Language parameter num: {sum(p.numel() for p in falcon.parameters())} on rank {args.rank}\n")
    to_logits = falcon.lm_head.to(args.rank)  

Loading Falcon



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding dimension will be 65027. This might induce some performance reduction as *Tensor Cores* will not be available. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc


Loading Flamcon
ViT parameter num: 242355496 on rank 0

Language parameter num: 6921734336 on rank 0



In [5]:
    model = Flamcon(
                    num_tokens = len(tokenizer),       # number of tokens
                    dim = args.dim,                     # dimensions
                    depth = 32,                         # depth
                    heads = 8,                          # attention heads
                    dim_head = 64,                      # dimension per attention head
                    img_encoder = vit,                  # plugin your image encoder (this can be optional if you pass in the image embeddings 
                    media_token_id = tokenizer.encode("<media>")[-1],                 # the token id representing the [media] or [image]
                    cross_attn_every = 3,               # how often to cross attend
                    perceiver_num_latents = 64,         # perceiver number of latents, should be smaller than the sequence length of the image tokens
                    perceiver_depth = 2,                # perceiver resampler depth
                    max_video_frames = args.max_frames, # max video frames
                    lang_model = falcon                 # llm
                    )

    del vit
    del falcon
    resume_from_epoch = 0
    checkpoint = None
    if os.path.exists(f"{args.run_name}") and args.resume:
        checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            print(f"Found no checkpoints for run {args.run_name}.\n")
        else:
            resume_from_checkpoint = sorted(
                checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
            print(
                f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}.\n"
            )
            checkpoint = load(resume_from_checkpoint, map_location="cpu")
            resume_from_epoch = checkpoint["epoch"] + 1
            if args.rank == 0:
                model.load_state_dict(checkpoint["model_state_dict"], False)
    print('Model Loaded')

Found checkpoint flamcon/checkpoint_1.pt for run flamcon.

Model Loaded


In [6]:
model = model.to(args.rank)

In [7]:
data = WebVidDataset("test_nw.csv","data",args.max_frames,tokenizer,args.max_tokens,test=True,samples=48)
sampler = DistributedSampler(data, rank=args.rank, num_replicas=args.world_size, shuffle=True)
dataloader = DataLoader(data,batch_size=6,sampler=sampler)

In [8]:
for batchid, data in enumerate(tqdm(dataloader,position=0, desc="Iters", leave=False, colour='green', ncols=80)):
    media, X, y, file, = data
    media = media.to('cuda')
    text = test(args, model, args.rank, X, media, tokenizer, to_logits)
    print(file[0],text)
    print('\n')
    break

                                                                                ?, ?it/s][0m

ValueError: Expected parameter logits (Tensor of shape (6, 65027)) of distribution Categorical(logits: torch.Size([6, 65027])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0')

In [26]:
    predicted = to_logits(text_tokens)[:,-input_ids_test.shape[1]:,:]
    #predicted = rearrange(predicted, 'b n c -> b c n')    
    input_ids_test[input_ids_test == tokenizer.pad_token_id] = -100
    input_ids_test[input_ids_test == tokenizer.eos_token] = -100
    input_ids_test[input_ids_test == tokenizer.encode("<media>")[-1]] = -100
        
    loss_fct = nn.CrossEntropyLoss()
    losses = 0
    for i in range(len(predicted)):
        loss = loss_fct(
            predicted[i], input_ids_test[i]
        )
        losses+=loss
    losses=losses/len(predicted)
    print(losses)

tensor(35.7711, device='cuda:0')


In [30]:
sampler(predicted)

tensor([[ 1003,  1003],
        [ 8801, 31920],
        [11060, 14560],
        [ 8801, 31920],
        [ 8801,  8801],
        [31920, 31920]], device='cuda:0')

In [47]:
tokenizer.batch_decode(input_ids[:,-input_ids_test.shape[1]:])

['<PAD><PAD>',
 '<PAD><PAD>',
 'woman sitting',
 '<PAD><PAD>',
 '<PAD><PAD>',
 '<PAD><PAD>']

In [45]:
tokenizer.batch_decode(sampler(predicted))

[' beingcurrent',
 'beingdoing',
 '© outdoors',
 'doingdoing',
 'beingbeing',
 'doingdoing']

In [24]:
input_ids, attention_mask = tokenize(tokenizer,X)
input_ids = input_ids.to(args.rank)
input_ids_test, attention_mask = tokenize(tokenizer,y)
input_ids_test = input_ids_test.to(args.rank)
sampler = NucleusSampler(0.95, TemperatureSampler(1.))
input_ids = torch.tile(input_ids[None, :], (1, 1))[0]
logs = [[(X, Text.meta)] for _ in range(1)]
seq_len = len(X)
for i in range(25):
    text_tokens = model.forward(input_ids,videos=media,gen=True,attention_mask=attention_mask)
    loss = getLoss(text_tokens, input_ids_test, to_logits, tokenizer)
    print(loss)
#     logits = to_logits(output)
#     logits = logits[:, -1, :]
#     res = sampler(logits)
#     #print(tokenizer.batch_decode(res,skip_special_tokens=True))
#     input_ids = torch.cat([input_ids, res], dim=1)
    #for j in range(1):
        #logs[j] += [('' + tokenizer.decode(res.item()), Text.value)]

torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0')
torch.Size([6, 2, 65027])
tensor(35.7711, device='cuda:0

In [45]:
import pandas as pd
import re

In [8]:
def prepWebVid(data):
    rows = []
    for idx,row in data.iterrows():
        text = re.findall(r"[\w']+|[.,!?;]",row["name"]) 
        for i in range(len(text)-1):
            rw = row.copy()
            rw['X'] = ' '.join(text[:i+1])
            rw['y'] = text[i+1]
            rows.append(rw)
    rows = pd.DataFrame(rows)
    rows.reset_index()
    return rows

In [10]:
train = pd.read_csv('data/train.csv')
train = train.drop('Unnamed: 0',axis=1)
test = pd.read_csv('data/test.csv')
test = test.drop('Unnamed: 0',axis=1)

In [11]:
train = prepWebVid(train)
train.to_csv('data/train_nw.csv')

In [12]:
test = prepWebVid(test)
test.to_csv('data/test_nw.csv')

In [63]:
train_df = pd.read_csv('data/train_nw.csv')

In [66]:
train_df[train_df['videoid'].isin(train_df['videoid'].unique()[:500])]

Unnamed: 0.1,Unnamed: 0,videoid,contentUrl,duration,page_dir,name,X,y
0,0,26431865,https://ak.picdn.net/shutterstock/videos/26431...,PT00H00M10S,143201_143250,In a modern private office young male and fema...,In,a
1,0,26431865,https://ak.picdn.net/shutterstock/videos/26431...,PT00H00M10S,143201_143250,In a modern private office young male and fema...,In a,modern
2,0,26431865,https://ak.picdn.net/shutterstock/videos/26431...,PT00H00M10S,143201_143250,In a modern private office young male and fema...,In a modern,private
3,0,26431865,https://ak.picdn.net/shutterstock/videos/26431...,PT00H00M10S,143201_143250,In a modern private office young male and fema...,In a modern private,office
4,0,26431865,https://ak.picdn.net/shutterstock/videos/26431...,PT00H00M10S,143201_143250,In a modern private office young male and fema...,In a modern private office,young
...,...,...,...,...,...,...,...,...
9779,499,1012328408,https://ak.picdn.net/shutterstock/videos/10123...,PT00H00M11S,012451_012500,"Donbass ukraine 2018 ruins of the kiosk, broke...","Donbass ukraine 2018 ruins of the kiosk , broken",windows
9780,499,1012328408,https://ak.picdn.net/shutterstock/videos/10123...,PT00H00M11S,012451_012500,"Donbass ukraine 2018 ruins of the kiosk, broke...","Donbass ukraine 2018 ruins of the kiosk , brok...",","
9781,499,1012328408,https://ak.picdn.net/shutterstock/videos/10123...,PT00H00M11S,012451_012500,"Donbass ukraine 2018 ruins of the kiosk, broke...","Donbass ukraine 2018 ruins of the kiosk , brok...",abandoned
9782,499,1012328408,https://ak.picdn.net/shutterstock/videos/10123...,PT00H00M11S,012451_012500,"Donbass ukraine 2018 ruins of the kiosk, broke...","Donbass ukraine 2018 ruins of the kiosk , brok...",building


In [47]:
test_df = pd.read_csv('data/test_nw.csv')

In [62]:
test_df[test_df['videoid'].isin(test_df['videoid'].unique()[:50])]

Unnamed: 0.1,Unnamed: 0,videoid,contentUrl,duration,page_dir,name,X,y
0,0,1032280022,https://ak.picdn.net/shutterstock/videos/10322...,PT00H00M32S,070301_070350,Young man and woman working out outdoors in th...,Young,man
1,0,1032280022,https://ak.picdn.net/shutterstock/videos/10322...,PT00H00M32S,070301_070350,Young man and woman working out outdoors in th...,Young man,and
2,0,1032280022,https://ak.picdn.net/shutterstock/videos/10322...,PT00H00M32S,070301_070350,Young man and woman working out outdoors in th...,Young man and,woman
3,0,1032280022,https://ak.picdn.net/shutterstock/videos/10322...,PT00H00M32S,070301_070350,Young man and woman working out outdoors in th...,Young man and woman,working
4,0,1032280022,https://ak.picdn.net/shutterstock/videos/10322...,PT00H00M32S,070301_070350,Young man and woman working out outdoors in th...,Young man and woman working,out
...,...,...,...,...,...,...,...,...
987,49,1019991253,https://ak.picdn.net/shutterstock/videos/10199...,PT00H00M11S,078651_078700,"Venice, italy, 1974, palazzo, ca' rezonica, in...","Venice , italy , 1974 , palazzo , ca' rezonica...",","
988,49,1019991253,https://ak.picdn.net/shutterstock/videos/10199...,PT00H00M11S,078651_078700,"Venice, italy, 1974, palazzo, ca' rezonica, in...","Venice , italy , 1974 , palazzo , ca' rezonica...",ceiling
989,49,1019991253,https://ak.picdn.net/shutterstock/videos/10199...,PT00H00M11S,078651_078700,"Venice, italy, 1974, palazzo, ca' rezonica, in...","Venice , italy , 1974 , palazzo , ca' rezonica...",painting
990,49,1019991253,https://ak.picdn.net/shutterstock/videos/10199...,PT00H00M11S,078651_078700,"Venice, italy, 1974, palazzo, ca' rezonica, in...","Venice , italy , 1974 , palazzo , ca' rezonica...",","
