In [None]:
# Mount Google Drive at /content/drive; the cached tensors and path lists used below live there.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
#!unzip /content/drive/MyDrive/out_files.zip
#!unzip /content/drive/MyDrive/MindBigData-Imagenet-IN.zip

In [None]:
#@title Setup
!pip install omegaconf einops
from IPython.display import clear_output
clear_output()
!mkdir vit_vqgan
%cd vit_vqgan
!gdown 1DbHEBNzjefNfwG0AKvYKB64if5Usua2n
!gdown 1-9INRFzvxDlQxyLX3fGA9ZnL3HXYyeTf
!gdown 1HzNvpeqvUTHz9tQOiV6G2r5sHKKhKQUZ
!gdown 1syv0t3nAJ-bETFgFpztw9cPXghanUaM6
clear_output()

In [None]:
!gdown 1DbHEBNzjefNfwG0AKvYKB64if5Usua2n
!gdown 1-9INRFzvxDlQxyLX3fGA9ZnL3HXYyeTf
!gdown 1HzNvpeqvUTHz9tQOiV6G2r5sHKKhKQUZ
!gdown 1syv0t3nAJ-bETFgFpztw9cPXghanUaM6

In [None]:
#@title Imports
import sys
import os

# The ViT-VQGAN code was downloaded into this directory by the Setup cell; make it the working
# directory (so './imagenet_vitvq_base.ckpt' resolves) and importable before `import vitvqgan`.
os.chdir("/content/vit_vqgan")
sys.path.append("/content/vit_vqgan")

# Standard library
import fnmatch
import gc
import glob
import io
import math
import re

# Third party
import cv2
import numpy as np
import pandas as pd
import PIL
import requests
import torch
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from IPython.display import clear_output
from matplotlib import pyplot as plt
from PIL import Image
from sklearn import metrics, model_selection
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from torch.utils.data import DataLoader
from torchvision.utils import save_image

# Local (downloaded) code
from vitvqgan import ViTVQ

# Seed every RNG the notebook draws from so runs are reproducible.
torch.manual_seed(0)
np.random.seed(0)

In [None]:
#@title Functions
def show_ims(recon,original):
    """Plot ``original`` and ``recon`` side by side in one 10x7-inch figure.

    The left panel is titled "Original", the right "Reconstructed"; axes are hidden.
    Both inputs are converted with ``np.array`` before plotting (e.g. PIL images).
    """
    fig = plt.figure(figsize=(10, 7))
    panels = (("Original", original), ("Reconstructed", recon))
    for position, (title, picture) in enumerate(panels, start=1):
        fig.add_subplot(1, 2, position)
        plt.imshow(np.array(picture))
        plt.axis('off')
        plt.title(title)

def download_image(url, timeout=30):
    """Download an image over HTTP(S) and open it with PIL.

    Args:
        url: address of the image.
        timeout: seconds to wait for the server before giving up. Without a timeout,
            ``requests`` may block indefinitely on an unresponsive host.

    Returns:
        A ``PIL.Image.Image`` backed by the downloaded bytes.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within ``timeout`` seconds.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

def preprocess(img):
    """Resize a PIL image to 256x256 and also return it as a tensor.

    Returns:
        ``(resized_image, tensor)`` where ``tensor`` is ``T.ToTensor()`` applied to the
        resized image.
    """
    resized = img.resize((256, 256))
    as_tensor = T.ToTensor()(resized)
    return resized, as_tensor


# Converter from an image tensor back to a PIL image (used for displaying reconstructions).
to_Pil = T.ToPILImage()

In [None]:
#@title Model

class PositionalEncoding(torch.nn.Module):
    """Fixed sinusoidal positional encoding added to the input, followed by dropout.

    Args:
        d_model: embedding size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence that can be encoded.
        batch_first: if True, inputs are ``(batch, seq_len, d_model)``; otherwise
            ``(seq_len, batch, d_model)`` (the original layout, and the default).
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = False):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim], or
               [batch_size, seq_len, embedding_dim] when ``batch_first`` is True.
        """
        if self.batch_first:
            # pe is stored as (max_len, 1, d); move the sequence axis to dim 1.
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TSParti(torch.nn.Module):
    """
    Encoder-decoder transformer mapping a brain time series to logits over image-codebook tokens.

    The series is projected to ``d_model``, given sinusoidal positions and encoded. The decoder
    receives a learned start token followed by the embedded image tokens (plus a 2-D axial
    position embedding) and, under a causal mask, predicts each token from the previous ones.
    The output projection shares its weight with the token embedding and no softmax is applied,
    so the logits can be passed directly to ``cross_entropy``.

    Args:
        ts_dim: features per time step of ``src``.
        ts_seq_len: longest source sequence (size of the positional-encoding table).
        nhead: attention heads.
        num_encoder_layers: transformer encoder layers.
        vae_path: unused; kept for backward compatibility.
    """

    def __init__(self, ts_dim=5, ts_seq_len=360, nhead=4, num_encoder_layers=3, vae_path='./imagenet_vitvq_base.ckpt'):
        super(TSParti, self).__init__()

        self.d_model = 32
        im_seq_len = 1024  # image tokens per picture
        n_embed = 8192  # codebook size
        # Image tokens form a square grid; the axial embeddings index its rows and columns.
        self.im_side = math.isqrt(im_seq_len)

        self.nhead = nhead

        self.ts_projection = torch.nn.Linear(ts_dim, self.d_model)
        self.transformer_model = torch.nn.Transformer(nhead=nhead, num_encoder_layers=num_encoder_layers, batch_first=True, d_model=self.d_model)
        self.start_token = torch.nn.Parameter(torch.randn(1, 1, self.d_model))
        self.image_token_embed = torch.nn.Embedding(n_embed, self.d_model)

        self.pos_encoder = PositionalEncoding(self.d_model, 0.1, max_len=ts_seq_len)
        self.axial_height_pos = torch.nn.Parameter(torch.randn(self.im_side, self.d_model))
        self.axial_width_pos = torch.nn.Parameter(torch.randn(self.im_side, self.d_model))

        self.to_logits = torch.nn.Linear(self.d_model, n_embed, bias=False)
        self.to_logits.weight = self.image_token_embed.weight

    def _axial_pos_emb(self):
        """Row-major grid of summed (height, width) embeddings, flattened to (im_side**2, d_model)."""
        grid = self.axial_height_pos.unsqueeze(1) + self.axial_width_pos.unsqueeze(0)
        return grid.reshape(-1, self.d_model)

    def _mask_fits(self, mask, bs, tgt_len):
        """Whether a caller-supplied decoder mask has a shape usable for this batch."""
        if mask is None or tuple(mask.shape[-2:]) != (tgt_len, tgt_len):
            return False
        return mask.dim() == 2 or (mask.dim() == 3 and mask.shape[0] == bs * self.nhead)

    def forward(self, src, tgt_codes, tgt_mask=None):
        """Teacher-forced forward pass.

        Args:
            src: brain time series, (batch, seq_len <= ts_seq_len, ts_dim).
            tgt_codes: image-token indices, (batch, n). The decoder input is the start token
                followed by ``tgt_codes[:, :-1]``, so position i predicts ``tgt_codes[:, i]``.
            tgt_mask: optional additive decoder mask, (n, n) or (batch * nhead, n, n). If None,
                or if its shape does not fit this batch (e.g. a smaller final batch), a causal
                mask is used so no position can attend to later positions (its own label).

        Returns:
            Logits of shape (batch, n, n_embed).
        """
        bs = src.shape[0]

        st = self.start_token.expand(bs, -1, -1)
        tgt = self.image_token_embed(tgt_codes[:, :-1])
        tgt = tgt + self._axial_pos_emb()[:tgt.shape[1]]
        tgt = torch.cat([st, tgt], 1)
        tgt_len = tgt.shape[1]

        if not self._mask_fits(tgt_mask, bs, tgt_len):
            tgt_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf'), device=tgt.device), diagonal=1)

        src = self.ts_projection(src)
        # pos_encoder indexes dim 0 as the sequence axis, while src is batch-first; positions are
        # added before the encoder so self-attention can use temporal order.
        src = self.pos_encoder(src.transpose(0, 1)).transpose(0, 1)
        enc_out = self.transformer_model.encoder(src)
        dec_out = self.transformer_model.decoder(tgt, enc_out, tgt_mask=tgt_mask.to(tgt.device))
        return self.to_logits(dec_out)

In [None]:
#@title Data
def trim_brain_ts(brain_ts, cutoff):
    """Center-crop a recording to exactly ``cutoff`` rows.

    Inputs with at most ``cutoff`` rows are returned unchanged. Longer inputs keep the middle
    ``cutoff`` rows; when the surplus is odd, the extra row is dropped from the end.
    Works for anything supporting ``len()`` and positional slicing (lists, tensors, DataFrames).
    """
    if len(brain_ts) > cutoff:
        start = (len(brain_ts) - cutoff) // 2
        # Slice by length: trimming `start` rows from both ends left cutoff + 1 rows
        # whenever the surplus was odd.
        brain_ts = brain_ts[start:start + cutoff]
    return brain_ts

class BrainwaveDataset(torch.utils.data.Dataset):
    """Pairs pre-extracted, normalized brain time series with the images they correspond to.

    Args:
        new_csvs: paths of the source CSV recordings (stored only; the series are read from
            ``tensor_path``).
        new_imgs: image paths; item ``i`` pairs ``new_imgs[i]`` with the i-th stored series.
        tensor_path: file with a tensor of stacked series, indexed by sample along dim 0
            (assumed (num_samples, seq_len, channels) -- TODO confirm).
    """

    def __init__(self, new_csvs, new_imgs, tensor_path):
        self.new_csvs = new_csvs
        self.new_imgs = new_imgs
        self.tss = torch.load(tensor_path)
        # Statistics over all samples and time steps (dims 0 and 1) -> one value per channel.
        pooled = self.tss.flatten(0, 1)
        self.mean = pooled.mean(0)
        self.std = pooled.std(0)
        # Constant channels would divide by zero; leave them centered but unscaled.
        self.std = torch.where(self.std > 0, self.std, torch.ones_like(self.std))
        self.tss = (self.tss - self.mean) / self.std

        self.tss = self.tss[:len(new_imgs)]

    def __len__(self):
        return len(self.new_imgs)

    def __getitem__(self, idx):
        """Return ``{'src': normalized series, 'img': preprocessed RGB image tensor}``."""
        # Open inside `with` so the file handle is released once the image is converted.
        with Image.open(self.new_imgs[idx]) as pil_img:
            _, img = preprocess(pil_img.convert('RGB'))
        src = self.tss[idx]
        return {'src': src, 'img': img}

In [None]:
def show_im(im):
  """Save tensor ``im`` to ``im.png`` with torchvision's ``save_image`` and return it reopened as a PIL image."""
  target = 'im.png'
  save_image(im, target)
  return Image.open(target)

In [None]:
# --- Data / sequence geometry ---
bs = 64
ts_seq_len = 360  # rows kept per brain recording (also the trim cutoff below)
ts_dim = 5  # features per time step
im_seq_len = 1024  # image tokens per picture
im_dim = 32

CUTOFF = ts_seq_len

# --- Model / training ---
nhead = 4
num_encoder_layers = 3
epochs = 1000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Cached tensors on Google Drive ---
tss_train_tensor_path = os.path.join('/content/drive/MyDrive/TSS Tensors', 'train_tss.pt')
tss_test_tensor_path = os.path.join('/content/drive/MyDrive/TSS Tensors', 'test_tss.pt')
codes_train_tensor_path = os.path.join('/content/drive/MyDrive/TSS Tensors', 'train_codes.pt')
codes_test_tensor_path = os.path.join('/content/drive/MyDrive/TSS Tensors', 'test_codes.pt')
new_imgs_csvs_labels = os.path.join('/content/drive/MyDrive/TSS Tensors', 'new_imgs_csvs_labels.pt')

# Matched image paths, CSV paths and labels, saved earlier with torch.save.
new_imgs, new_csvs, labels = torch.load(new_imgs_csvs_labels)

#new_imgs = [new_imgs[0]]

In [None]:
# ViT-VQGAN (tokenizer) architecture matching the downloaded ImageNet checkpoint.
encoder = {'dim': 768, 'depth': 12, 'heads': 12, 'mlp_dim': 3072}
decoder = {'dim': 768, 'depth': 12, 'heads': 12, 'mlp_dim': 3072}
quantizer = {'embed_dim': 32, 'n_embed': 8192}
vae_path = './imagenet_vitvq_base.ckpt'
tsparti = TSParti(ts_dim=ts_dim, ts_seq_len=ts_seq_len, nhead=nhead, num_encoder_layers=num_encoder_layers).to(device)
vae = ViTVQ(256, 8, encoder, decoder, quantizer, path=vae_path).to(device)
# The tokenizer is frozen: no gradients, and eval mode so any train-only behaviour
# (e.g. dropout, if the model has it) does not make the encoded codes random.
vae.requires_grad_(False)
vae.eval()

In [None]:
#new_imgs_train, new_imgs_test, new_csvs_train, new_csvs_test = train_test_split(new_imgs, new_csvs, test_size=0.1, random_state=42)
# NOTE(review): the random split above is disabled, so train and test share the same image/CSV
# lists; only the stored brain-series tensors differ. Confirm this is intended.
new_imgs_train = new_imgs
new_imgs_test = new_imgs
new_csvs_train = new_csvs
new_csvs_test = new_csvs

ds_train = BrainwaveDataset(new_csvs_train, new_imgs_train, tss_train_tensor_path)
ds_test = BrainwaveDataset(new_csvs_test, new_imgs_test, tss_test_tensor_path)

# shuffle=False: batches come out in dataset order.
train_dataloader = DataLoader(dataset=ds_train, batch_size=bs, shuffle=False)
test_dataloader = DataLoader(dataset=ds_test, batch_size=bs, shuffle=False)

In [None]:
# Tokenize every training image with the frozen ViT-VQGAN and cache the codes on Drive.
vae.eval()
all_codes = []
with torch.no_grad():
    for samp in train_dataloader:
        img = samp['img'].to(device)
        tgt_codes = vae.encode_codes(img)
        all_codes.append(tgt_codes)

all_codes = torch.cat(all_codes, 0)
# Save a CPU copy so the file can be loaded without a GPU.
torch.save(all_codes.cpu(), codes_train_tensor_path)

In [None]:
# Tokenize every test image with the frozen ViT-VQGAN and cache the codes on Drive.
vae.eval()
all_codes = []
with torch.no_grad():
    for samp in test_dataloader:
        img = samp['img'].to(device)
        tgt_codes = vae.encode_codes(img)
        all_codes.append(tgt_codes)

all_codes = torch.cat(all_codes, 0)
# Save a CPU copy so the file can be loaded without a GPU.
torch.save(all_codes.cpu(), codes_test_tensor_path)

In [None]:
#iterator = iter(train_dataloader)

In [None]:
optimizer = torch.optim.AdamW(params=tsparti.parameters(), lr=1e-3)
# Reduce the LR (down to 1e-7) once the value passed to scheduler.step() has not improved
# by more than `threshold` for 100 consecutive calls.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=100,
    threshold=0.001,
    min_lr=1e-7,
)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tsparti = tsparti.to(device)
vae = vae.to(device)
tsparti.train()  # a previous sampling run may have left the model in eval mode
vae.eval()  # frozen tokenizer

# Additive causal mask: position i may attend only to positions <= i, so the decoder cannot
# peek at the token it must predict. It is built once (the length is fixed). It is 2-D, so it is
# shared by every head and works for any batch size, including a smaller final batch.
tgt_mask = torch.triu(torch.full((im_seq_len, im_seq_len), float('-inf')), diagonal=1).to(device)

for epoch in range(epochs):
  for i, samp in enumerate(train_dataloader):
    src = samp['src'].to(device)
    img = samp['img'].to(device)
    # NOTE(review): codes for the training set are also cached on Drive by an earlier cell;
    # they are re-encoded here on every step.
    with torch.no_grad():
      tgt_codes = vae.encode_codes(img)

    optimizer.zero_grad()
    logits = tsparti(src, tgt_codes, tgt_mask)
    # cross_entropy expects the class dimension second: (batch, n_embed, seq_len).
    loss = torch.nn.functional.cross_entropy(logits.permute(0, 2, 1), tgt_codes)
    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())
    lr = optimizer.param_groups[0]['lr']
    print(f'Epoch: {epoch}; Step: {i}; loss: {loss.item()}; lr: {lr}')

In [None]:
def log(t, eps = 1e-20):
    """Natural log of ``t + eps``; the shift turns log(0) into a large finite negative value."""
    shifted = t + eps
    return shifted.log()
  
def gumbel_noise(t):
    """Gumbel noise shaped like ``t``: ``-log(-log(U))`` with ``U ~ Uniform(0, 1)`` (via the eps-shifted ``log``)."""
    uniform = torch.empty_like(t).uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)

def gumbel_sample(t, temperature = 1., dim = -1):
    """Gumbel-max sampling: argmax along ``dim`` of temperature-scaled logits plus Gumbel noise."""
    perturbed = t / temperature + gumbel_noise(t)
    return perturbed.argmax(dim=dim)

def top_k(logits, thres = 0.5):
    """Keep the largest ``max(int((1 - thres) * vocab), 1)`` logits along the last dim; set the rest to -inf.

    Works on any shape ``(..., vocab)``: every leading index (batch element, position) is
    filtered independently. Returns a tensor with the same shape as ``logits``.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k, dim=-1)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, ind, val)
    return filtered

In [None]:
import copy
from tqdm import tqdm
from PIL import Image


def generate(tsparti, src, tgt):
  """Run TSParti's encoder and decoder on an already-embedded decoder input ``tgt``.

  Low-level helper: it re-implements the encoder path inline (projection -> transformer
  encoder -> positional encoding) and runs the decoder without a target mask. The sampling
  loop below calls ``tsparti(...)`` directly instead, so sampling builds the decoder input
  (start token + token embeddings + positions) exactly as the training loop does.
  """
  src = tsparti.ts_projection(src)
  enc_out = tsparti.transformer_model.encoder(src)
  enc_out = tsparti.pos_encoder(enc_out)
  dec_out = tsparti.transformer_model.decoder(tgt, enc_out)
  return tsparti.to_logits(dec_out)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tsparti = tsparti.to(device)
vae = vae.to(device)
tsparti.eval()  # no dropout while sampling

# Condition on the first recording of the first training batch.
samp = next(iter(train_dataloader))
src = samp['src'][0].unsqueeze(0).to(device)
img = samp['img'][0].to(device)

# Autoregressively sample im_seq_len image codes. TSParti.forward drops the last code it is given
# (that code is only the label of the final position), so a placeholder is appended each step;
# the logits at the last position then score the next code.
codes = torch.zeros((1, 0), dtype=torch.long, device=device)
placeholder = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():
  for _ in tqdm(range(im_seq_len)):
    logits = tsparti(src, torch.cat([codes, placeholder], 1), None)[:, -1:, :]
    filtered_logits = top_k(logits, thres = 0.9)
    token_code = gumbel_sample(filtered_logits, temperature = 1, dim = -1)
    codes = torch.cat([codes, token_code], 1)

In [None]:
import IPython.display as ipd

# Decode the sampled codes back to an image with the frozen ViT-VQGAN and display it.
with torch.no_grad():
  emb = vae.decode_codes(codes)

ipd.display(show_im(emb[0]))


In [None]:
# Decode the tokenizer codes of the most recent batch (`tgt_codes`) for a visual
# reference; the first image is shown as this cell's output.
with torch.no_grad():
  emb = vae.decode_codes(tgt_codes)

show_im(emb[0])

In [None]:
# Disabled (kept as a string): matches ImageNet images to MindBigData CSV recordings and builds
# the shuffled new_imgs / new_csvs / labels lists. The trailing `;` stops Jupyter from echoing
# the whole string as this cell's output.
'''imgs = glob.glob('/content/out_files/ILSVRC/Data/CLS-LOC/train/*/*.JPEG')
csvs = glob.glob('/content/MindBigData-Imagenet/*.csv')

with open('/content/WordReport-v1.04.txt') as f:
  class_data = f.read()
class_data = list(map(lambda x: x.split('\t'), class_data.split('\n')))[:-1]
class_dict = {el[2]:el[0] for el in class_data}

csv_ids = list(map(lambda x: x[59:].rsplit('_', 2)[0], csvs))
img_ids = list(map(lambda x: x[55:-5], imgs))
ids = np.unique([el_mp for el_mp in csv_ids if el_mp in img_ids])

new_imgs, new_csvs = [], []
for id in ids:
  csv_matched = fnmatch.filter(csvs, '*' + id + '*.csv')
  img_matched = fnmatch.filter(imgs, '*' + id + '.JPEG')
  new_csvs += csv_matched
  new_imgs += img_matched * len(csv_matched)
  #break

labels = [class_dict[re.findall('n[0-9]+', el)[0]] for el in new_csvs]
new_csvs, new_imgs, labels = shuffle(new_csvs, new_imgs, labels, random_state=0)''';

In [None]:
# Disabled (kept as a string): stacks the per-sample series of both datasets and saves them to
# the tss_*_tensor_path files. The trailing `;` stops Jupyter from echoing the string.
'''tss_train = torch.cat([ds_train[i]['src'].unsqueeze(0) for i in range(len(ds_train))], 0)
torch.save(tss_train, tss_train_tensor_path)

tss_test = torch.cat([ds_test[i]['src'].unsqueeze(0) for i in range(len(ds_test))], 0)
torch.save(tss_test, tss_test_tensor_path)''';

In [None]:
# Disabled reference snippet (kept as a string; it uses `self` and text-conditioning helpers that
# do not exist in this notebook): top-k + Gumbel autoregressive image-token sampling.
# The trailing `;` stops Jupyter from echoing the string.
'''text_token_embeds, text_mask = self.encode_texts(texts, output_device = device)

batch = text_token_embeds.shape[0]

image_seq_len = self.image_encoded_dim ** 2

image_tokens = torch.empty((batch, 0), device = device, dtype = torch.long)

for _ in range(image_seq_len):
    logits = self.forward_with_cond_scale(
        text_token_embeds = text_token_embeds,
        text_mask = text_mask,
        image_token_ids = image_tokens
    )[:, -1]

    filtered_logits = top_k(logits, thres = filter_thres)
    sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

    sampled = rearrange(sampled, 'b -> b 1')
    image_tokens = torch.cat((image_tokens, sampled), dim = -1)

image_tokens = rearrange(image_tokens, 'b (h w) -> b h w', h = self.image_encoded_dim)''';

In [None]:
# Disabled experiment (kept as a string): tokenizing text with a T5 tokenizer.
# The trailing `;` stops Jupyter from echoing the string.
'''!pip install transformers==4.11.2
!pip install --upgrade sentencepiece
from transformers import T5Tokenizer, T5EncoderModel, T5Config
name = 'google/t5-v1_1-base'
tokenizer = T5Tokenizer.from_pretrained(name)
encoded = tokenizer.batch_encode_plus(
        ['Customize your Wells Fargo Debit Card. Choose from over 30 collegiate designs'],
        return_tensors = "pt",
        padding = 'longest',
        max_length = 256,
        truncation = True
    )''';

In [None]:
# Disabled demo (kept as a string): downloads an image and shows its ViT-VQGAN reconstruction.
# The trailing `;` stops Jupyter from echoing the string.
'''gc.collect()
#torch.cuda.empty_cache()

image_path="https://jw-webmagazine.com/wp-content/uploads/2020/03/Kimetsu-no-YaibaDemon-Slayer.jpg"
if "https" in image_path:
  original=download_image(image_path)
else:
  if os.path.exists(image_path):
    original=Image.open(image_path)
  else:
    print("Please check the image path")

encoder = {'dim': 768, 'depth': 12,
           'heads': 12, 'mlp_dim': 3072}
decoder = {'dim': 768, 'depth': 12,
           'heads': 12, 'mlp_dim': 3072}
quantizer = {'embed_dim': 32, 'n_embed': 8192}

image, image_t = preprocess(original)
image_t = image_t#.cuda()
clear_output()
model = ViTVQ(256, 8, encoder, decoder, quantizer, path='./imagenet_vitvq_base.ckpt')#.cuda()
recon, _ = model(image_t)
recon = to_Pil(recon.squeeze(0))#.squeeze(0).permute(1,2,0).detach().cpu() * 255
#recon = recon.to(torch.uint8)

show_ims(recon,image)
print("original saved at vit_vqgan/original.png ")
print("reconstructed saved at vit_vqgan/reconstructed.png ")''';