In [5]:
# 📦 1. Importações
import os
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# Seu modelo e utilitários
from model import DeiTForFewShot 

In [4]:
# 2. Pre-processing transforms
# Every image is resized to 224x224, converted to a tensor and then normalised
# channel-wise.
# NOTE(review): the mean/std values look like the usual ImageNet statistics --
# confirm they match what the backbone was trained with.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


In [9]:
# Directories holding the support set (one sub-folder per class) and the query set
suporte_dir = "dataset_suporte"
consulta_dir = "dataset_consulta"

# Pre-processing (intended to be compatible with DeiT)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # or whatever input size your backbone expects
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# List the class sub-folders. Plain files and hidden entries (for example the
# ".ipynb_checkpoints" folder Jupyter creates) are skipped: they are not classes
# and would otherwise crash the loop below or add a phantom class.
classes = sorted(
    entry
    for entry in os.listdir(suporte_dir)
    if not entry.startswith(".")
    and os.path.isdir(os.path.join(suporte_dir, entry))
)
print("Classes detectadas no suporte:", classes)

# Map each class name to an integer label (its position in the sorted list)
classe_para_indice = {classe: idx for idx, classe in enumerate(classes)}

# Accumulate per-image tensors and labels in plain lists first, so that
# `support_images` / `support_labels` only ever hold the final tensors.
support_image_list = []
support_label_list = []

for classe in classes:
    classe_path = os.path.join(suporte_dir, classe)
    for nome_img in sorted(os.listdir(classe_path)):
        img_path = os.path.join(classe_path, nome_img)
        # Skip hidden files and nested directories; only regular files are images.
        if nome_img.startswith(".") or not os.path.isfile(img_path):
            continue
        # The context manager closes the underlying file handle deterministically.
        with Image.open(img_path) as imagem:
            support_image_list.append(transform(imagem.convert("RGB")))
        support_label_list.append(classe_para_indice[classe])

# Fail early with a clear message instead of an opaque torch.stack error.
if not support_image_list:
    raise RuntimeError(f"No support images found under {suporte_dir!r}")

# Stack into tensors and add a leading batch dimension of 1 to the images
support_images = torch.stack(support_image_list).unsqueeze(0)  # [1, N_supp, 3, 224, 224]
support_labels = torch.tensor(support_label_list)              # [N_supp]


Classes detectadas no suporte: ['A', 'Airfield', 'Airplane', 'Avião', 'Banheiro', 'Bus', 'C', 'Cafeteria', 'Camarim', 'Carro', 'Casa', 'Castelo', 'Castle', 'Estrada', 'Highway', 'Piscina', 'Quarto', 'Quarto de Criança', 'Sala de aula', 'Sala de estar', 'Ônibus']


In [10]:
# Load the query set: every regular, non-hidden file directly inside `consulta_dir`,
# in sorted filename order.
query_paths = [
    os.path.join(consulta_dir, nome)
    for nome in sorted(os.listdir(consulta_dir))
    if not nome.startswith(".")
    and os.path.isfile(os.path.join(consulta_dir, nome))
]

# Fail early with a clear message instead of an opaque torch.stack error.
if not query_paths:
    raise RuntimeError(f"No query images found under {consulta_dir!r}")

query_tensors = []
for query_path in query_paths:
    # The context manager closes the underlying file handle deterministically.
    with Image.open(query_path) as query_img:
        query_tensors.append(transform(query_img.convert("RGB")))

# Stack and add a leading batch dimension of 1
query_images = torch.stack(query_tensors).unsqueeze(0)  # [1, N_query, 3, 224, 224]


In [14]:
import argparse

def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` is a classic argparse pitfall: ``bool("False")`` is ``True``
    because any non-empty string is truthy. This converter accepts the usual
    spellings and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Build the few-shot learning argument parser and return the parsed namespace.

    Args:
        argv: optional list of command-line style tokens to parse. When omitted
            (the default) an empty list is parsed, so every option takes its
            default value and the notebook kernel's own ``sys.argv`` is never read.

    Returns:
        argparse.Namespace holding one attribute per option declared below.
    """
    parser = argparse.ArgumentParser('Few-shot learning script', add_help=False)

    # --- Arguments copied from the project's get_args_parser ---
    # (alternatively, import the parser directly from your module)
    parser.add_argument('--batch-size', default=1, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--output_dir', default='outputs/tmp',
                        help='path where to save, empty for no saving')
    parser.add_argument('--seed', default=0, type=int)
    # _str2bool instead of type=bool, which would turn "False" into True.
    parser.add_argument('--deterministic', default=False, type=_str2bool)
    parser.add_argument('--experiment_name', default="",
                        help='name of the experiment running to create apropriate file.')

    parser.add_argument("--extract_features", action='store_true')
    parser.add_argument("--dataset_path", default="", type=str)

    parser.add_argument("--classify", action="store_true")
    parser.add_argument("--support_file", type=str)
    parser.add_argument("--query_file", type=str)

    # --wandb / --no-wandb toggle the same destination; enabled by default.
    parser.add_argument("--wandb", dest='wandb', action='store_true')
    parser.add_argument("--no-wandb", dest='wandb', action='store_false')
    parser.set_defaults(wandb=True)
    parser.add_argument("--project-name", default="FSL-Transformers", type=str)

    parser.add_argument("--dataset",
                        choices=["places", "places_600", "test", "final_test", "csam", "litmus"],
                        default="places_600",
                        help="Which few-shot dataset.")

    # Episode layout
    parser.add_argument("--nClsEpisode", default=8, type=int,
                        help="Number of categories in each episode.")
    parser.add_argument("--nSupport", default=5, type=int,
                        help="Number of samples per category in the support set.")
    parser.add_argument("--nQuery", default=15, type=int,
                        help="Number of samples per category in the query set.")
    parser.add_argument("--nEpisode", default=2000, type=int,
                        help="Number of episodes for training / testing.")

    # Model
    parser.add_argument('--pretrained_weights', default='', type=str,
                        help="Path to pretrained weights to evaluate.")
    parser.add_argument('--backbone', default='deit_small',
                        choices=['deit', 'resnet50', 'dino', 'resnet50_dino',
                                 'deit_small', 'resnet18', 'vit_mini'])
    parser.add_argument('--aggregator', default='average',
                        choices=['average', 'max', 'log_sum_exp', 'lp_pool', 'self_attn'])
    parser.add_argument('--temperature', default=0.1, type=float,
                        help='temperature to be applyed to cosine similarities')

    # Meta-test augmentation
    parser.add_argument('--aug_prob', default=0.9, type=float,
                        help='Probability of applying data augmentation during meta-testing')
    parser.add_argument('--aug_types', nargs="+", default=['color', 'translation'],
                        help='color, offset, offset_h, offset_v, translation, cutout')

    parser.add_argument('--img-size', default=224, type=int, help='images input size')

    # Optimizer
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--optimizer', type=str, default='sgd', choices=['sgd', 'adam'])

    # Learning-rate schedule. Help texts state the defaults actually set here.
    parser.add_argument('--sched', default='step', type=str,
                        choices=["cosine", "step", "exponential", "None"], metavar='SCHEDULER',
                        help='LR scheduler (default: "step")')
    parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
                        help='learning rate (default: 5e-5)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')

    parser.add_argument('--decay-epochs', type=float, default=10, metavar='N',
                        help='epoch interval to decay LR (step scheduler)')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.5, metavar='RATE',
                        help='LR decay rate (default: 0.5)')

    # Training-time augmentation
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    # The original help literal embedded a stray '" + \' concatenation; rebuilt here.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.0,
                        help='Label smoothing (default: 0.0)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')

    # Checkpointing / runtime
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    # NOTE(review): type=tuple splits a command-line string into single characters
    # (tuple("0.9") == ('0', '.', '9')). Left unchanged because the intended format
    # is not visible here -- confirm and replace with a proper parser.
    parser.add_argument('--max_acc', default=None, type=tuple,
                        help='Max accuracy obtained in training before')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    # With store_true plus this default, pin_mem is True whether or not the flag is given.
    parser.set_defaults(pin_mem=True)

    parser.add_argument('--device', default='cuda',
                        help='cuda:gpu_id for single GPU training')
    # --- End of arguments ---

    # Parse an explicit token list (empty by default) rather than sys.argv,
    # which inside a notebook holds the kernel's own arguments.
    return parser.parse_args([] if argv is None else list(argv))


args = get_args()
print("Args configurado com sucesso!")


Args configurado com sucesso!


In [21]:
import torch
from model import DeiTForFewShot  # ajuste o import conforme seu projeto

# Instantiate the few-shot model from the parsed configuration
model = DeiTForFewShot(args)

# Replace with the correct path to your .pth checkpoint file
checkpoint_path = "proxyfsl/thamiris_FSL_places600_best.pth"

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code -- only load checkpoints from a trusted source.
checkpoint = torch.load(
    checkpoint_path,
    map_location="cpu",
    weights_only=False,
)

# Use the entry stored under 'model' when the checkpoint has one; otherwise
# treat the whole checkpoint as the state dict.
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint

model.load_state_dict(state_dict)
model.eval()  # switch to evaluation mode before running inference

print("Modelo carregado com sucesso!")

Some weights of DeiTModel were not initialized from the model checkpoint at facebook/deit-small-distilled-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Modelo carregado com sucesso!


In [26]:
# Sanity check: push one image through the model's feature extractor.
# The image is pre-processed with `transform` and given a leading batch
# dimension of 1 before being passed to the model.
image = transform(Image.open("proxyfsl/highway.jpg").convert("RGB")).unsqueeze(0)

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    embedding = model.get_features(image)

# Length of the first embedding row (384 in the recorded output).
len(embedding[0])

384