In [1]:
import os
import torch
import timm
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import albumentations as A
from PIL import Image
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score
import numpy as np
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR, CosineAnnealingLR, ReduceLROnPlateau, StepLR, LambdaLR

from utils import seed_torch, current_date_time, init_logger

# Optional GPU pinning, left disabled:
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Send both HTTP and HTTPS traffic of this process through the same
# proxy endpoint on the local machine.
os.environ.update(
    dict.fromkeys(("HTTP_PROXY", "HTTPS_PROXY"), "http://127.0.0.1:7890")
)


class CFG:
    """Run configuration.

    Every attribute whose name does not start with ``__`` is written to the
    run log by the loop that follows logger creation, in definition order.
    """
    seed = 42  # handed to seed_torch() right after this class
    num_workers = 4  # NOTE(review): not read by any cell visible in this notebook

    batch_size = 32  # images per forward pass in the embedding cells
    
    # Model identifier. Alternatives noted by the author:
    #   efficientnet_b0, swin_large_patch4_window7_224, swin_tiny_patch4_window7_224,
    #   vit_giant_patch14_reg4_dinov2.lvd142m, vit_small_patch14_reg4_dinov2.lvd142m
    # NOTE(review): the model-loading cell calls
    # torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg') and never
    # reads this value, so the logged model_name does not describe the model
    # that actually produces the embeddings -- confirm which one is intended.
    model_name = 'vit_small_patch14_reg4_dinov2.lvd142m'
    img_size = 224  # passed to transforms.Resize; alternatives noted: 128, 224, 518

   
# Seed the RNGs first so everything below is reproducible.
seed_torch(CFG.seed)

# Derive a short run tag from the current timestamp: strip '-', ':' and ' ',
# then keep characters 4..11 of what is left.
# NOTE(review): presumably MMDDHHMM if current_date_time() returns
# 'YYYY-MM-DD HH:MM:SS' -- verify against utils.
cur_time = current_date_time()
cur_time_abbr = cur_time.translate(str.maketrans("", "", "-: "))[4:12]

# Each run gets its own folder under ./output, which also hosts the log file.
output_dir = "/".join(('./output', f"{cur_time_abbr}_get_embedding"))
os.makedirs(output_dir, exist_ok=True)
LOGGER = init_logger(f'{output_dir}/get_embedding.log')

# Record the configuration in the log, skipping dunder entries of the class.
for key, value in vars(CFG).items():
    if key.startswith("__"):
        continue
    LOGGER.info(f"{key} = {value}")


def _train_image_path(image_id):
    """Return the JPEG path of a training image given its id."""
    return f'./data/train_images/{image_id}.jpeg'


def _test_image_path(image_id):
    """Return the JPEG path of a test image given its id."""
    return f'./data/test_images/test_images/{image_id}.jpeg'


# Read the train/test tables and attach a `file_path` column built from `id`.
train_df = pd.read_csv('./data/train.csv').assign(
    file_path=lambda frame: frame['id'].apply(_train_image_path)
)
test_df = pd.read_csv('./data/test.csv').assign(
    file_path=lambda frame: frame['id'].apply(_test_image_path)
)

print(f"{train_df.shape = }")
print(f"{test_df.shape = }")

# Pick the compute device: CUDA when available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"{device = }")

seed = 42
num_workers = 4
batch_size = 32
model_name = vit_small_patch14_reg4_dinov2.lvd142m
img_size = 224


train_df.shape = (43363, 170)
test_df.shape = (6391, 164)
device = device(type='cuda')


Using cache found in /home/br/.cache/torch/hub/facebookresearch_dinov2_main


DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1536, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-39): 40 x NestedTensorBlock(
      (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=1536, out_features=4608, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1536, out_features=1536, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (mlp): SwiGLUFFNFused(
        (w12): Linear(in_features=1536, out_features=8192, bias=True)
        (w3): Linear(in_features=4096, out_features=1536, bias=True)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
  (head

In [2]:
# Load the DINOv2 ViT-g/14 (with registers) backbone from torch.hub and put it
# in inference mode (disables dropout / stochastic depth).
# NOTE(review): this entry point is hard-coded and does not follow
# CFG.model_name -- confirm which model the run is meant to use.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg').to(device)
model.eval()

# Preprocessing (no augmentation is applied here).
transform = transforms.Compose([
    # Resize to an explicit (H, W): a bare int only fixes the shorter edge, so
    # non-square images would yield tensors of different shapes that cannot be
    # stacked into a batch. For square inputs the result is unchanged.
    # InterpolationMode.BICUBIC replaces the deprecated integer code 3.
    transforms.Resize(
        (CFG.img_size, CFG.img_size),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    # Per-channel mean / std used to normalise the three input channels.
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [6]:
def get_image_embeddings_dino(model, preprocess, batch_size, df):
    """Compute one embedding per row of ``df`` by running ``model`` on its image.

    Rows are processed in consecutive chunks of ``batch_size`` in the row order
    of ``df``; the final chunk may be smaller. Gradients are disabled during
    the forward pass. The batch is moved to the module-level ``device``.

    Args:
        model: Callable applied to the stacked batch tensor; its result must
            support ``.cpu().numpy()``.
        preprocess: Callable turning a PIL image into a tensor. All tensors of
            a batch must share one shape so they can be stacked.
        batch_size: Number of rows per forward pass.
        df: DataFrame with a ``file_path`` column of image paths.

    Returns:
        list: one entry per row of ``df`` (the rows of each batch output
        converted to NumPy), in row order. Empty if ``df`` has no rows.
    """
    def _load(path):
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle as soon as the pixels have been consumed, so a
        # long run does not accumulate open file descriptors.
        with Image.open(path) as img:
            # Force three channels so grayscale / CMYK files still match the
            # 3-channel normalisation and the model's input layer.
            return preprocess(img.convert('RGB'))

    file_paths = df['file_path']
    image_embeddings = []
    for start in tqdm(range(0, len(df), batch_size)):
        # .iloc makes the positional slicing explicit, whatever the index is.
        batch_paths = file_paths.iloc[start:start + batch_size]
        image_tensor = torch.stack([_load(path) for path in batch_paths]).to(device)
        with torch.no_grad():
            batch_embeddings = model(image_tensor)
        # extend() iterates over the first axis, appending one array per image.
        image_embeddings.extend(batch_embeddings.cpu().numpy())
    return image_embeddings

In [7]:
# Embed every training image, then persist the result as one array in the
# current working directory (np.save adds the ".npy" suffix itself).
train_image_embeddings = get_image_embeddings_dino(
    model=model,
    preprocess=transform,
    batch_size=CFG.batch_size,
    df=train_df,
)
np.save('train_dinov2_embeds', np.asarray(train_image_embeddings))

100%|██████████| 1356/1356 [29:25<00:00,  1.30s/it]


In [12]:
# Embed every test image, then persist the result as one array in the
# current working directory (np.save adds the ".npy" suffix itself).
test_image_embeddings = get_image_embeddings_dino(
    model=model,
    preprocess=transform,
    batch_size=CFG.batch_size,
    df=test_df,
)
np.save('test_dinov2_embeds', np.asarray(test_image_embeddings))

100%|██████████| 200/200 [04:15<00:00,  1.28s/it]
