In [1]:
import torch
import torch.nn as nn

In [2]:
# Full encoder-decoder Transformer: 12 encoder layers, 16 attention heads,
# default d_model=512. Inputs use the default (seq, batch, feature) layout
# (batch_first=False), so src is 10 steps and tgt is 20 steps, batch of 32.
transformer_model = nn.Transformer(num_encoder_layers=12, nhead=16)
src, tgt = torch.rand(10, 32, 512), torch.rand(20, 32, 512)
out = transformer_model(src=src, tgt=tgt)

In [3]:
# Decoder output follows tgt's shape: (tgt_len, batch, d_model).
out.size()

torch.Size([20, 32, 512])

In [4]:
# Encoder-only stack: six deep copies of one 8-head, 512-wide layer, applied
# to a random (seq=10, batch=32, feature=512) input.
encoder_layer = nn.TransformerEncoderLayer(nhead=8, d_model=512)
transformer_encoder = nn.TransformerEncoder(num_layers=6, encoder_layer=encoder_layer)
src = torch.rand(10, 32, 512)
out = transformer_encoder(src)

In [5]:
# (Evaluate `transformer_encoder` in this cell to display its layer stack.)

In [6]:
# First ten input features of step 0, batch item 0.
src[0][0][:10]

tensor([0.9591, 0.8684, 0.5397, 0.3370, 0.3336, 0.8257, 0.9730, 0.7644, 0.9825,
        0.5655])

In [7]:
# The same ten positions after the encoder stack, for comparison with the input.
out[0][0][:10]

tensor([ 0.1918,  1.0481, -0.3919,  0.6234, -1.2013, -1.0217,  1.0253,  0.8566,
        -1.6811,  0.0522], grad_fn=<SliceBackward0>)

In [8]:
import numpy as np

In [9]:
# Load the training data. The path is relative to the notebook's working directory.
# NOTE(review): allow_pickle=True unpickles object arrays, which can execute
# arbitrary code; only load files from a trusted source.
# .item() unwraps the single stored Python object (presumably a dict, given the
# ['sequences'][...] indexing below) from its numpy array wrapper.
user_train = np.load('../data/user_train.npy', allow_pickle=True).item()

In [10]:
# Keypoint array of one sequence, selected by its ID.
seq = (
    user_train['sequences']
    ['01FJRKCP4GE1W1DFX51C']
    ['keypoints']
)
np.shape(seq)

(4500, 11, 24, 2)

In [11]:
# Number of missing (NaN) keypoint values in this sequence.
np.sum(np.isnan(seq))

432000

In [12]:
# Prepare `seq` for the encoder:
#   1. replace missing values (NaN) with 0 on a copy, leaving `seq` untouched;
#   2. flatten the two trailing axes into one feature vector per (step, fly).
# Despite the name, no scaling/normalisation is applied here.
seq_norm = seq.copy()
seq_norm[np.isnan(seq_norm)] = 0
assert seq_norm.ndim == 4, f"expected 4-D keypoints, got shape {seq_norm.shape}"
# Take only the leading dims rather than unpacking into `_`: in IPython a
# user-assigned `_` shadows the output-history variable and stops it updating.
seq_len, num_flies = seq_norm.shape[:2]
seq_norm = seq_norm.reshape(seq_len, num_flies, -1)
# Explicit float32 (what the legacy torch.Tensor(ndarray) constructor produced
# under the default dtype) instead of the deprecated-style constructor.
seq_norm = torch.as_tensor(seq_norm, dtype=torch.float32)
seq_norm.shape

torch.Size([4500, 11, 48])

In [13]:
# Sanity check: NaN filling left no missing values (expect tensor(0)).
seq_norm.isnan().sum()

tensor(0)

In [14]:
# Encoder sized to the flattened keypoint features: d_model=48 matches the
# last dim of `seq_norm`, and 8 heads gives a per-head width of 6.
encoder_layer = nn.TransformerEncoderLayer(nhead=8, d_model=48)
model = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=6)

In [25]:
from tqdm.auto import tqdm

In [29]:
# Prefer the GPU but fall back to CPU, so the notebook still runs on machines
# without CUDA (a hard-coded 'cuda' makes every later `.to(DEVICE)` fail there).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [37]:
bs = 64  # rows of dim 0 per chunk (see note below)

model = model.to(DEVICE)
seq_norm = seq_norm.to(DEVICE)

# nn.TransformerEncoder defaults to batch_first=False, so dim 0 of `seq_norm`
# is the sequence axis and dim 1 is the batch. Splitting along dim 0 therefore
# yields windows of `bs` steps, and attention is computed only within each
# window. NOTE(review): confirm windowed attention (not batching) is the intent.
# torch.split splits by chunk *size*: every chunk has exactly `bs` rows (the
# last may be shorter), and sequences shorter than `bs` still work.
batches = torch.split(seq_norm, bs, dim=0)

# Inference only: eval() disables dropout so outputs are deterministic, and
# no_grad() stops autograd from keeping every chunk's graph alive in memory.
model.eval()
outs = []
with torch.no_grad():
    for batch in tqdm(batches):
        outs.append(model(batch))
outs = torch.cat(outs)
assert outs.shape == seq_norm.shape

  0%|          | 0/71 [00:00<?, ?it/s]

In [36]:
# Concatenated encoder output; should match the input's (seq, batch, 48) shape.
outs.size()

torch.Size([4500, 11, 48])

In [2]:
from transformers import ViTModel, ViTConfig

In [3]:
# Randomly initialised ViT built from the default configuration
# (vit-base-patch16-224 style). `configuration` ends up as the config object
# the model holds.
model = ViTModel(ViTConfig())
configuration = model.config

In [4]:
# Display the module tree of the randomly initialised ViT built above.
model

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): PatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )

In [14]:
from transformers import ViTFeatureExtractor, ViTModel
import torch
from datasets import load_dataset

# Sample input: the first image of the test split of the cats-image dataset.
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"][0]["image"]

# Pretrained ViT-Base/16 checkpoint plus its matching preprocessor.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

# Preprocess into a batch of pixel values (shape [1, 3, 224, 224]).
inputs = feature_extractor(image, return_tensors="pt")

# Forward pass without autograd bookkeeping.
with torch.no_grad():
    outputs = model(**inputs)

# One CLS token plus (224 / 16) ** 2 = 196 patch tokens -> [1, 197, 768].
last_hidden_states = outputs.last_hidden_state
list(last_hidden_states.size())

No config specified, defaulting to: cats_image/image
Reusing dataset cats_image (/home/iz/.cache/huggingface/datasets/huggingface___cats_image/image/1.9.0/68fbc793fb10cd165e490867f5d61fa366086ea40c73e549a020103dcb4f597e)


  0%|          | 0/1 [00:00<?, ?it/s]

[1, 197, 768]

In [18]:
# Display the module tree of the pretrained ViT loaded above.
model

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): PatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )

In [17]:
# Full model output object (last_hidden_state and pooler_output); the repr is
# long, so inspecting individual fields or their shapes is usually more useful.
outputs

BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.1559,  0.0914,  0.1518,  ..., -0.3180, -0.0859, -0.0903],
         [-0.2254,  0.0864,  0.4752,  ..., -0.1781,  0.1726,  0.1334],
         [ 0.0444,  0.0677,  0.4199,  ..., -0.2576,  0.1191,  0.0130],
         ...,
         [-0.0153, -0.0396,  0.1684,  ..., -0.1672,  0.1869,  0.1025],
         [ 0.0249, -0.0382,  0.2046,  ...,  0.0517,  0.1489,  0.1320],
         [-0.1748, -0.0254,  0.2523,  ..., -0.1474,  0.1627,  0.1325]]]), pooler_output=tensor([[ 5.8399e-02, -3.0683e-01,  3.1213e-01, -1.1009e-01, -1.4752e-01,
          4.9735e-01, -1.5786e-01,  4.8658e-01, -4.6255e-01,  2.4344e-01,
          2.9941e-02,  2.8738e-01, -4.8914e-01, -9.9516e-03, -2.8943e-01,
          3.1443e-01, -6.2883e-02, -2.6637e-01, -3.9652e-01,  2.9896e-01,
          2.1507e-01, -1.9265e-01,  1.1786e-01,  2.5995e-01,  3.5440e-01,
         -3.7968e-01,  4.8320e-01, -3.5686e-01,  2.3996e-01, -8.0731e-01,
          1.1701e-02,  4.9429e-01,  6.9714e-01,  5.273

In [15]:
# Preprocessed image batch: (batch, channels, height, width).
inputs['pixel_values'].size()

torch.Size([1, 3, 224, 224])