In [1]:
%cd ../

/home/users/dmoreno2016/VisionTransformers


In [2]:
import os
import glob
import webdataset as wds
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import importlib
import io
import torch

import random

from src.data.LitData import LitData
from scripts.utils import *

In [3]:
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [4]:
def get_shards_path(base_path, subset):
    shards_path = f'{base_path}/{subset}/fold_0'
    shard_files = [f for f in os.listdir(shards_path) if f.endswith(".tar.gz")]
    print(shard_files)
    
    if shard_files:
        min_index = min(int(f.split('-')[1].split('.')[0]) for f in shard_files)
        max_index = max(int(f.split('-')[1].split('.')[0]) for f in shard_files)
        dataset_url = f"{shards_path}/imgs_lc-{{{min_index:06d}..{max_index:06d}}}.tar.gz"
        return dataset_url
    else:
        raise FileNotFoundError(f"No shard files found in the directory: {shards_path}")

def fn_decode(sample):
    for key, value in sample.items():
        if key.endswith(".pth"):
            sample[key] = torch.load(io.BytesIO(value), weights_only=True)
        elif key.endswith(".txt"):
            sample[key] = value.decode("utf-8") 
        elif key.endswith(".cls"):
            try:
                sample[key] = int(value)
            except ValueError:
                sample[key] = 0
    return sample

def get_input_model(sample):
    sample = fn_decode(sample)
    sample_mod = sample['pixel_values.pth']
    sample_mod = sample_mod.permute(1, 2, 0, 3)
    sample_mod = sample_mod.reshape(sample_mod.size(0), sample_mod.size(1), -1)
    sample['pixel_values.pth'] = sample_mod

    input_dict = {
        'id': sample['id.txt'],
        'pixel_values': sample['pixel_values.pth'],
        'y_true': sample['label.cls'],
    }
    return input_dict

def get_data(base_path, subset):
    dataset_url = get_shards_path(base_path, subset)
    return (
        wds.WebDataset(dataset_url, shardshuffle=False)
        .map(lambda sample: get_input_model(sample))
    )

base_path = f'data/images/macho/all/minmax_by_obj_256'

train_dataloader = torch.utils.data.DataLoader(get_data(base_path, 'train'), batch_size=50)
val_dataloader = torch.utils.data.DataLoader(get_data(base_path, 'val'), batch_size=50)
test_dataloader = torch.utils.data.DataLoader(get_data(base_path, 'test'), batch_size=50)

['imgs_lc-000140.tar.gz', 'imgs_lc-000234.tar.gz', 'imgs_lc-000320.tar.gz', 'imgs_lc-000023.tar.gz', 'imgs_lc-000022.tar.gz', 'imgs_lc-000421.tar.gz', 'imgs_lc-000332.tar.gz', 'imgs_lc-000442.tar.gz', 'imgs_lc-000138.tar.gz', 'imgs_lc-000351.tar.gz', 'imgs_lc-000287.tar.gz', 'imgs_lc-000363.tar.gz', 'imgs_lc-000369.tar.gz', 'imgs_lc-000307.tar.gz', 'imgs_lc-000305.tar.gz', 'imgs_lc-000027.tar.gz', 'imgs_lc-000191.tar.gz', 'imgs_lc-000100.tar.gz', 'imgs_lc-000040.tar.gz', 'imgs_lc-000315.tar.gz', 'imgs_lc-000124.tar.gz', 'imgs_lc-000456.tar.gz', 'imgs_lc-000107.tar.gz', 'imgs_lc-000084.tar.gz', 'imgs_lc-000114.tar.gz', 'imgs_lc-000306.tar.gz', 'imgs_lc-000169.tar.gz', 'imgs_lc-000208.tar.gz', 'imgs_lc-000353.tar.gz', 'imgs_lc-000123.tar.gz', 'imgs_lc-000000.tar.gz', 'imgs_lc-000225.tar.gz', 'imgs_lc-000302.tar.gz', 'imgs_lc-000061.tar.gz', 'imgs_lc-000009.tar.gz', 'imgs_lc-000008.tar.gz', 'imgs_lc-000073.tar.gz', 'imgs_lc-000049.tar.gz', 'imgs_lc-000196.tar.gz', 'imgs_lc-000082.tar.gz',

In [5]:
config = {
    'mlflow_dir': 'results/ml-runs',

    'checkpoint': {
        'use': True,
        'exp_name': 'ft_classification/alcock/testing',
        'run_name': '2024-09-08_13-53-49',
        'results_dir': 'results',
    },

    'pretrained_model': {
        'use': True,
        'path': "microsoft/swinv2-tiny-patch4-window8-256",
    },

    'loader': {
        'fold': 0
        }
}

def load_model(data_info, config):
    model_name = config['model_name']
    LitModel_module = importlib.import_module(f"src.models.LitModels.{model_name}")
    model = getattr(LitModel_module, 'LitModel')(data_info, **config)
    return model

ckpt_dir = handle_ckpt_dir(config, fold=config['loader']['fold'])
ckpt_model = sorted(glob.glob(ckpt_dir + "/*.ckpt"))[-1]
hparams = load_yaml(f'{ckpt_dir}/hparams.yaml')

In [6]:
data_info = hparams.pop('data_info')
loaded_model = load_model(data_info=data_info,
                          config=hparams)
if os.path.exists(ckpt_model):
    loaded_model = load_checkpoint(loaded_model, ckpt_model)            
else:
    raise FileNotFoundError(f"Checkpoint file not found at {ckpt_dir}")

# Definir el dispositivo (CPU o GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_model.model.to(device)

2024-09-24 13:06:24.891333: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-24 13:06:24.908596: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-24 13:06:24.913874: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-24 13:06:24.926869: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  checkpoint = torch.load(ckpt_model)


Swinv2Model(
  (embeddings): Swinv2Embeddings(
    (patch_embeddings): Swinv2PatchEmbeddings(
      (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    )
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): Swinv2Encoder(
    (layers): ModuleList(
      (0): Swinv2Stage(
        (blocks): ModuleList(
          (0-1): 2 x Swinv2Layer(
            (attention): Swinv2Attention(
              (self): Swinv2SelfAttention(
                (continuous_position_bias_mlp): Sequential(
                  (0): Linear(in_features=2, out_features=512, bias=True)
                  (1): ReLU(inplace=True)
                  (2): Linear(in_features=512, out_features=3, bias=False)
                )
                (query): Linear(in_features=96, out_features=96, bias=True)
                (key): Linear(in_features=96, out_features=96, bias=False)
                (value): Linear(in_features=96, out_features=96, bias=

In [None]:
df = []
for i, batch_data in enumerate(train_dataloader):
    inputs = loaded_model.processor(images=batch_data['pixel_values'],
                                    return_tensors="pt").to(device)
    outputs = loaded_model.model(**inputs)
    pooled_output = outputs[1].detach().cpu().numpy()
    data = {'id': batch_data['id'], 
            'embedding': [pooled_output[i] for i in range(pooled_output.shape[0])]} 
    df.append(pd.DataFrame(data))

df = pd.concat(df).reset_index(drop=True)
df.to_parquet('chunk_1.parquet')

In [9]:
df = []
for i, batch_data in enumerate(val_dataloader):
    inputs = loaded_model.processor(images=batch_data['pixel_values'],
                                    return_tensors="pt").to(device)
    outputs = loaded_model.model(**inputs)
    pooled_output = outputs[1].detach().cpu().numpy()
    data = {'id': batch_data['id'], 
            'embedding': [pooled_output[i] for i in range(pooled_output.shape[0])]} 
    df.append(pd.DataFrame(data))

df = pd.concat(df).reset_index(drop=True)
df.to_parquet('chunk_2.parquet')

In [7]:
df = []
for i, batch_data in enumerate(test_dataloader):
    inputs = loaded_model.processor(images=batch_data['pixel_values'],
                                    return_tensors="pt").to(device)
    outputs = loaded_model.model(**inputs)
    pooled_output = outputs[1].detach().cpu().numpy()
    data = {'id': batch_data['id'], 
            'embedding': [pooled_output[i] for i in range(pooled_output.shape[0])]} 
    df.append(pd.DataFrame(data))

df = pd.concat(df).reset_index(drop=True)
df.to_parquet('chunk_3.parquet')

In [9]:
pd.read_parquet('chunk_3.parquet')

Unnamed: 0,id,embedding
0,b'F_1.3321.139',"[-1.284629, -0.31749094, -0.6400238, 2.437619,..."
1,b'F_1.3321.171',"[0.21544454, -0.9262556, -3.7031307, 2.2348804..."
2,b'F_1.3322.117',"[-0.96317387, -2.6845708, -0.02852206, 2.59417..."
3,b'F_1.3322.352',"[-0.65112597, -2.2339785, 0.6990346, -0.492593..."
4,b'F_1.3323.128',"[-0.19724986, -2.131647, 0.14896852, 1.3611845..."
...,...,...
152935,b'F_104.21423.320',"[-0.009841576, -3.6842074, 1.8758022, 1.780292..."
152936,b'F_104.21423.442',"[2.7456453, -3.3650455, 0.90351087, 1.8550711,..."
152937,b'F_104.21424.255',"[-0.7771113, -0.5712862, 1.3037164, -0.1520346..."
152938,b'F_104.21424.378',"[0.47180226, -2.4642267, 1.1030608, 1.0658232,..."


In [None]:
outputs[1].shape

torch.Size([2, 768])

In [None]:
outputs['last_hidden_state'].shape

torch.Size([2, 64, 768])

In [None]:
outputs[1]

tensor([[-1.5579,  0.1109,  3.1001,  ...,  0.9162,  3.7436, -2.5494],
        [ 2.1747,  0.3100, -0.3583,  ..., -0.7695,  3.8617, -2.1436]],
       grad_fn=<ViewBackward0>)