In [None]:
import os
#os.environ['CUDA_VISIBLE_DEVICES']="1"

import matplotlib.pyplot as plt

from pathlib import Path
from fastai import *
from fastai.vision.all import *
from fastai.callback import *
from fastai.data.transforms import get_image_files
import pandas as pd
from arch import RingGeMNet, GeMNet, L2Norm, GeM
import re

from fastprogress import master_bar, progress_bar
%config InlineBackend.figure_format ='retina'

In [None]:
# ./index -> points to index dir
COMP_DATA_DIR = Path('.')

# One row per index image. The paths are sorted so the row order is deterministic;
# later cells map descriptor rows back to file names purely by position in this frame.
df = pd.DataFrame({'Image' : sorted(get_image_files(COMP_DATA_DIR / 'index', recurse=True))})

In [None]:
# Quick look at the first index image paths.
df.head()

In [None]:
NUM_WORKERS=8
SIZE = 256
DO_FULL_SIZE = False

# Squash every image to SIZE x SIZE unless descriptors are extracted at full resolution.
item_tfms = None if DO_FULL_SIZE else Resize(SIZE, method='squish')
# Normalize with the ImageNet channel statistics.
batch_tfms = Normalize.from_stats(*imagenet_stats)

# Unlabelled DataBlock: images only, every row goes to the "train" subset.
data_block = DataBlock(
    blocks=(ImageBlock(),),  # inputs only, no target
    get_x=lambda row: row['Image'],  # image path stored in the 'Image' column
    # Identity split: keep the dataset in exactly the row order of `df`. Descriptor row k
    # is later mapped to `df.Image[k]` by position, so a random split would silently
    # attach descriptors to the wrong file names.
    splitter=lambda items: (list(range(len(items))), []),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms
)

# Un-resized images cannot be collated into one tensor, hence batch size 1 at full size.
BS = 1 if DO_FULL_SIZE else 64
# shuffle=False keeps iteration order equal to dataset order; drop_last=False makes sure
# the final, smaller batch is not skipped (otherwise the last images get no descriptor).
data = data_block.dataloaders(df, bs=BS, num_workers=NUM_WORKERS,
                              shuffle=False, drop_last=False)

# Show a few images as a sanity check.
data.show_batch(max_n=9, figsize=(8, 8))


In [None]:
# Print basic information about the index DataLoaders.
print("Batch size:", data.bs)
print("Number of training samples:", len(data.train_ds))
print("Number of validation samples:", len(data.valid_ds) if hasattr(data, 'valid_ds') else 0)
# len(data) would count the loaders held by the DataLoaders object, not the batches;
# the batch count is the length of the training loader itself.
print("Number of training batches:", len(data.train))



In [None]:
# Inspect the DataLoaders object.
print(data)

# Fetch and print only the first training batch, then stop iterating.
# NOTE(review): this prints the full batch contents, which can be very long.
for batch in data.train:
    print(batch)  # first batch only
    break

In [None]:
import torchvision
import pretrainedmodels

In [None]:
# Backbone constructor with the class count fixed to 1000.
arch = partial(pretrainedmodels.se_resnet152, num_classes=1000) 

# A functools.partial has no __name__; copy it from the wrapped function because the
# name is used below to build file names.
arch.__name__ = arch.func.__name__
model_fname =  None #'resnet152_i200_l1000-256'
basename_suffix = 'cut-extractor-2scales6patches-gem3'
size_fname = 'full' if DO_FULL_SIZE else str(SIZE)

# Common stem for every cached artefact (features, search results) of this configuration.
# Note that it already ends in '.pth'; later cells prepend a prefix and, for the NumPy
# caches, append '.npy' to it.
basename = f'{model_fname or arch.__name__}_{size_fname}_{basename_suffix}.pth'
print(basename)

In [None]:
class Extractor(nn.Module):
    """Network head that reduces a 4-D feature map to one descriptor per image.

    The input is unpacked as (batch, channels, height, width). It is pooled with
    `GeM(3.)`, passed through `L2Norm` and reshaped to (batch, 1, channels); the
    middle axis holds the single global descriptor.

    Only the whole-map descriptor is produced. Descriptors of sub-regions (centre
    crop, four quadrants) could be pooled the same way from slices of `x` and
    concatenated along dim 1 if multi-patch extraction is wanted.
    """

    def __init__(self):
        super().__init__()
        self.l2norm = L2Norm()
        # 3. is the argument given to GeM (presumably the pooling exponent — TODO confirm in arch.py).
        self.pool = GeM(3.)

    def forward(self, x):
        # Spatial sizes are unpacked only to enforce a 4-D input.
        n_batch, n_channels, _height, _width = x.shape
        pooled = self.pool(x)
        normalised = self.l2norm(pooled)
        return normalised.view(n_batch, 1, n_channels)

In [None]:
# Wrap the backbone in a fastai Learner: the body is cut at -1 and the Extractor head
# defined in this notebook is attached in place of a classifier head. Metrics and loss
# are passed here, but no training call appears in this notebook; the learner is used
# as a convenient way to build (and optionally load) the model.
learn = vision_learner(data, arch,pretrained='imagenet', custom_head=Extractor(),
                   metrics=[accuracy], cut= -1,
                   loss_func=nn.CrossEntropyLoss(), n_out=1000)

In [None]:
# Optionally load saved weights; strict=False is passed to learn.load.
# Without a saved model, the architecture name is used as the model name from here on.
if model_fname:
    learn = learn.load(model_fname, strict=False)
else:
    model_fname = arch.__name__


In [None]:
# Print the learner's model summary.
learn.summary()

In [None]:
# The model (backbone + Extractor head) that is used for descriptor extraction below.
InferenceNet =  learn.model

In [None]:
NUM_WORKERS=8

# One row per query image, sorted so that the row order is deterministic; query
# descriptors are mapped back to file names by position in this frame.
qdf = pd.DataFrame({'Image' : sorted(get_image_files(COMP_DATA_DIR / 'test', recurse=True))})
qdf.head()

In [None]:
# Squash every image to SIZE x SIZE unless descriptors are extracted at full resolution.
item_tfms = Resize(SIZE, method='squish') if not DO_FULL_SIZE else None
# Normalize with the ImageNet channel statistics.
batch_tfms = Normalize.from_stats(*imagenet_stats)

# Unlabelled DataBlock: images only, every row goes to the "train" subset.
data_block = DataBlock(
    blocks=(ImageBlock(),),  # inputs only, no target
    get_x=lambda row: row['Image'],  # image path stored in the 'Image' column
    # Identity split: keep the dataset in exactly the row order of `qdf`, because query
    # descriptors are matched to `qdf.Image` by position later on.
    splitter=lambda items: (list(range(len(items))), []),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms
)

# Un-resized images cannot be collated into one tensor, hence batch size 1 at full size.
BS = 1 if DO_FULL_SIZE else 64
# Sequential, complete iteration is requested from the loader directly: assigning a
# torch sampler to `qdata.train.sampler` afterwards does not change the order in which
# a fastai DataLoader iterates, so the order has to be fixed here.
qdata = data_block.dataloaders(qdf, bs=BS, num_workers=NUM_WORKERS,
                               shuffle=False, drop_last=False)

# Show a few images as a sanity check.
qdata.show_batch(max_n=9, figsize=(8, 8))

In [None]:
# Print basic information about the query DataLoaders.
print("Batch size:", qdata.bs)
print("Number of training samples:", len(qdata.train_ds))
print("Number of validation samples:", len(qdata.valid_ds) if hasattr(qdata, 'valid_ds') else 0)
# len(qdata) would count the loaders held by the DataLoaders object, not the batches;
# the batch count is the length of the training loader itself.
print("Number of training batches:", len(qdata.train))



In [None]:
from tqdm import tqdm  # import thư viện tqdm

device = torch.device("cpu")  # Run inference on the CPU only

def extract_vectors_batched(data, model, flip=False):
    """Run `model` over every batch of `data.train` and collect the outputs in one tensor.

    Args:
        data: object exposing `train_ds` (its length sizes the result) and `train`
            (the iterable of batches). Batches may be bare tensors or tuples/lists whose
            first element is the image tensor.
        model: module called on each image batch; it is moved to the module-level
            `device` and put in eval mode.
        flip: when True, every batch is extended with a copy flipped along dim 3
            (the width axis for NCHW batches), doubling the number of output rows.

    Returns:
        A CPU tensor of shape `(len(data.train_ds) * (2 if flip else 1), *out.shape[1:])`,
        or None when the loader yields no batch. With `flip=True`, rows are interleaved:
        row 2k is image k and row 2k+1 is its flipped copy.

    The loader must iterate in dataset order and must not drop the last batch, otherwise
    rows do not line up with the dataset (and trailing rows stay zero).
    """
    model.to(device)
    model.eval()
    n_flip = 2 if flip else 1
    n_img = len(data.train_ds) * n_flip
    vectors = None
    st = 0  # write offset; advanced by the real batch size so short batches are handled

    with torch.no_grad():
        # tqdm wraps the loader to show progress.
        for batch in tqdm(data.train, desc="Processing images", unit="batch"):
            # Unlabelled loaders still yield a tuple; the images are its first element.
            img = batch[0] if isinstance(batch, (tuple, list)) else batch
            # The loader may place batches on another device than the model.
            img = img.to(device)
            if flip:
                img = torch.cat((img, img.flip([3])))
            out = model(img).cpu()
            if vectors is None:
                # Allocated lazily because the descriptor shape is only known now.
                vectors = torch.zeros(n_img, *out.shape[1:])
            fin = st + out.shape[0]
            if flip:
                half = out.shape[0] // 2
                vectors[st:fin:2, ...] = out[:half, ...]      # originals on even rows
                vectors[st + 1:fin:2, ...] = out[half:, ...]  # flipped copies on odd rows
            else:
                vectors[st:fin, ...] = out
            st = fin
    return vectors


In [None]:
flip = True
p_flip = 'flip' if flip else ''
# Load cached index descriptors if possible, otherwise extract and cache them.
try:
    index_features = torch.load(( f'index{p_flip}_{basename}'), weights_only=True)
except Exception as err:
    # `Exception` rather than a bare except: interrupting the load with Ctrl-C must not
    # start a long re-extraction, and the reason for the cache miss is worth showing.
    print("Tệp tin không tồn tại hoặc bị lỗi, đang tạo lại index_features.")
    print(f"  ({type(err).__name__}: {err})")
    index_features = extract_vectors_batched(data, InferenceNet, flip)
    torch.save(index_features, f'index{p_flip}_{basename}')

In [None]:
flip = True
p_flip = 'flip' if flip else ''
# Load cached query descriptors if possible, otherwise extract and cache them.
try:
    query_features = torch.load(( f'query{p_flip}_{basename}'),weights_only=True)
except Exception as err:
    # `Exception` rather than a bare except: interrupting the load with Ctrl-C must not
    # start a long re-extraction, and the reason for the cache miss is worth showing.
    print("Tệp tin không tồn tại hoặc bị lỗi, đang tạo lại query_features.")
    print(f"  ({type(err).__name__}: {err})")
    query_features = extract_vectors_batched(qdata,InferenceNet, flip)
    torch.save(query_features, f'query{p_flip}_{basename}')


In [None]:
# Inspect the query descriptors loaded or computed above.
query_features

In [None]:
# Inspect the index descriptors loaded or computed above.
index_features

In [None]:
import faiss
def flatten(list2d):
    """Concatenate the items of every inner iterable of `list2d` into one flat list."""
    return [item for inner in list2d for item in inner]

# duplicate b/c we're going to have image, image_LR, image, image_LR, ...
# (with flip=True every query contributes two descriptor rows: the image and its mirror)
query_fnames = flatten([[x.stem, x.stem] for x in qdf.Image.tolist()])
# Index names are NOT duplicated: the submission cell halves the returned row numbers
# (// 2) to get back from descriptor rows to positions in this list.
index_fnames = [x.stem for x in df.Image.tolist()]


In [None]:
import gc
# Drop the references to the learner, the model and any search-related globals, then
# run the garbage collector; only the extracted features are needed from here on.
learn, InferenceNet, co, res, flat_config, cpu_index, index = None, None, None, None, None, None, None
gc.collect()

In [None]:
def t_pcawhitenlearn(X):
    """Learn a PCA-whitening transform from unlabelled samples.

    Args:
        X: 2-D tensor with one sample per row, shape (N, D).

    Returns:
        (m, P): `m` is the (1, D) sample mean. `P` is a (D, D) matrix whose row i is the
        i-th principal direction (largest variance first) divided by the square root of
        its eigenvalue, so `(X - m) @ P.t()` has identity covariance.

    Raises:
        Whatever `torch.inverse` raises when the covariance has a zero eigenvalue.
    """
    N = X.shape[0]

    # Learning PCA w/o annotations
    m = X.mean(dim=0, keepdim=True)
    Xc = X - m
    Xcov = Xc.t() @ Xc
    # Symmetrise to remove round-off asymmetry and divide by N in one step.
    Xcov = (Xcov + Xcov.t()) / (2*N)
    # torch.symeig no longer exists in current PyTorch; torch.linalg.eigh is its
    # replacement and likewise returns ascending eigenvalues with eigenvectors as columns.
    eigval, eigvec = torch.linalg.eigh(Xcov)
    order = eigval.argsort(descending=True)
    eigval = eigval[order]
    eigvec = eigvec[:, order]

    P = torch.inverse(torch.sqrt(torch.diag(eigval))) @ eigvec.t()

    return m, P

def t_whitenapply(X, m, P, dimensions=None):
    """Apply a learned PCA-whitening transform and L2-normalise the result.

    Args:
        X: 2-D tensor with one sample per row, shape (N, D).
        m: mean to subtract, broadcastable against X (e.g. shape (1, D)).
        P: whitening matrix with one output component per ROW, as returned by
           `t_pcawhitenlearn`.
        dimensions: number of leading components to keep; all of them when falsy.

    Returns:
        Tensor of shape (N, dimensions) whose rows have (almost) unit L2 norm.
    """
    if not dimensions: dimensions = P.shape[0]

    # Components are the rows of P, so row-vector samples are projected with the
    # transpose of the first `dimensions` rows. Multiplying by P[:, :dimensions]
    # would apply the transposed (i.e. a different) transform.
    X = (X-m) @ P[:dimensions, :].t()
    # The small constant avoids a division by zero for all-zero rows.
    X = X / (torch.norm(X, dim=1, keepdim=True) + 1e-6)
    return X

def _to_faiss_array(features):
    """Return 2-D descriptors (tensor or array-like) as a C-contiguous float32 NumPy array."""
    if isinstance(features, Tensor):
        features = features.detach().cpu().numpy()
    return np.ascontiguousarray(features, dtype=np.float32)

def get_idxs_and_dists(_query_features, _index_features, index_type='', BS = 32, max_hits=200):
    """Search every query descriptor against a faiss index built from the index descriptors.

    Args:
        _query_features: (n_queries, d) tensor or array of query descriptors.
        _index_features: (n_index, d) tensor or array of index descriptors.
        index_type: faiss `index_factory` description string. Pre-processing named in it
            is carried out by faiss itself; nothing is transformed on the torch side here
            (the former torch PCA path sat behind `if False:` and never ran).
        BS: number of queries sent to faiss per search call.
        max_hits: number of neighbours returned per query.

    Returns:
        (out_idxs, out_dists): int32 and float32 arrays of shape (n_queries, max_hits)
        holding, per query, the row numbers into `_index_features` and the matching
        values reported by faiss. If the index holds fewer than `max_hits` vectors faiss
        is expected to pad the indices with -1 (see the faiss docs); callers should
        ignore such entries.
    """
    query_features = _to_faiss_array(_query_features)
    index_features = _to_faiss_array(_index_features)
    n_queries = query_features.shape[0]

    # One descriptor ("patch") per image is searched, hence the constant 1.
    print(query_features.shape, index_features.shape, n_queries, 1)

    _index = faiss.index_factory(index_features.shape[1], index_type)#, faiss.METRIC_INNER_PRODUCT)
    try:
        # Everything GPU-specific stays inside the try so that a CPU-only faiss build
        # (which lacks these classes) falls back cleanly instead of raising here.
        co = faiss.GpuMultipleClonerOptions()
        co.shard=True
        co.shard_type=1
        co.useFloat16=True
        # The clone is the object to train and search; discarding it left `index`
        # unbound whenever the GPU transfer succeeded.
        index = faiss.index_cpu_to_all_gpus(_index,co=co)
        print("Index in GPU")
    except Exception:
        index = _index
        print("Index in CPU")
    print("Training index...", end="")
    index.train(index_features)
    print("done")
    print("Adding features to index...", end="")
    index.add(index_features)
    print("done")

    out_dists = np.zeros((n_queries, max_hits), dtype=np.float32)
    out_idxs  = np.zeros((n_queries, max_hits), dtype=np.int32)
    for ind in progress_bar(range(0, n_queries, BS)):
        fin = min(ind + BS, n_queries)
        D, I = index.search(query_features[ind:fin], max_hits)
        out_dists[ind:fin] = D
        out_idxs[ind:fin] = I
    return out_idxs, out_dists

In [None]:
# Show the value faiss reports from omp_get_max_threads().
faiss.omp_get_max_threads()

In [None]:
# Shape check; the last axis is used as the PCA dimension in the next cell.
query_features.shape

In [None]:
# faiss factory string: PCA + whitening to the full descriptor size, L2 normalisation,
# then an exact (flat) index.
index_type=f"PCAW{query_features.shape[-1]},L2norm,Flat"

# Reuse cached search results if both files can be read, otherwise search and cache.
# NOTE(review): the cache names do not contain the flip prefix used for the feature
# files, so results cached for another flip setting would be picked up here.
try:
    out_idxs  = np.load(f'idx_{basename}.npy')
    out_dists = np.load(f'dist_{basename}.npy')
except (OSError, ValueError, EOFError) as err:
    # Only the errors np.load raises for a missing or unreadable file trigger a new
    # search; anything else (including Ctrl-C) propagates.
    print(f"Search cache unavailable ({type(err).__name__}: {err}); running the search.")
    out_idxs, out_dists = get_idxs_and_dists(
        query_features.squeeze(1), 
        index_features.squeeze(1), BS = 32, index_type=index_type)
    np.save(f'idx_{basename}.npy',  out_idxs)
    np.save(f'dist_{basename}.npy', out_dists)

In [None]:
# Display the returned values of every query sorted in ascending order (one row per query row).
np.sort(out_dists.reshape((-1,int(out_idxs.shape[1]*1))), axis=1)

In [None]:
# Both arrays should have one row per query descriptor and one column per hit.
out_idxs.shape, out_dists.shape

In [None]:
sub_fname = 'test_submission.csv'
# NOTE(review): absolute, machine-specific path — the notebook only runs where this file
# exists; consider deriving it from COMP_DATA_DIR.
sample_df = pd.read_csv('/home/cvdcl/22521167/test.csv')
# Start every sample row with an empty result list; these rows are merged with the
# predictions further down.
sample_df['images'] = ''

In [None]:
# Build the submission: for every query merge the hits of the image and of its mirrored
# copy, keep each index image once, and list the 100 best index file names.
sub = {}
for i, query_fname in progress_bar(enumerate(query_fnames), total=len(query_fnames)):
    # Rows i (original) and i+1 (mirrored) belong to the same query; handle them together.
    if i % 2: continue
    idx = np.concatenate([out_idxs[i], out_idxs[i+1]], axis=0)
    dst = np.concatenate([out_dists[i],out_dists[i+1]], axis=0)

    # Negative row numbers are padding for missing hits; halving them would point at
    # the last index image, so they are dropped first.
    keep = idx >= 0
    # Index rows 2k and 2k+1 are image k and its mirror, hence the integer division.
    idx, dst = idx[keep] // 2, dst[keep]

    # Rank all hits by ascending value first (smaller = better, as in the original
    # ranking), so that the FIRST occurrence of an image is also its BEST one.
    # Taking first occurrences of the unsorted concatenation could keep a worse
    # duplicate and mis-rank the image.
    by_dist = np.argsort(dst, kind='stable')
    idx, dst = idx[by_dist], dst[by_dist]
    u_idx = np.sort(np.unique(idx, return_index=True)[1])
    _out_idxs = idx[u_idx]

    ids = [index_fnames[x] for x in _out_idxs[:100]]
    sub[query_fname] = ' '.join(ids)

In [None]:
# Print the result list of every query.
# NOTE(review): one multi-line entry per query — the output grows with the query count.
for query_fname, ids in sub.items():
    print(f"Query file: {query_fname}")
    print(f"Top 100 nearest neighbors: {ids}")
    print()

In [None]:
# Predictions as a frame with the columns 'id' and 'images'.
sub_df = pd.DataFrame({'id' : list(sub.keys()), 'images':list(sub.values())})
# Append the sample rows and de-duplicate on 'id'. Predictions come first in the concat,
# and drop_duplicates keeps the first occurrence by default, so sample rows presumably
# only survive for ids without a prediction (assumes sample_df has an 'id' column).
sub_df = pd.concat([sub_df, sample_df]).drop_duplicates(subset=['id'])
sub_df.to_csv(sub_fname, index=False)

In [None]:
# Preview the first eight submission rows.
sub_df.iloc[:8]

In [None]:
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
from fastai.vision.all import *
import matplotlib.pyplot as plt
from pathlib import Path

def fix_path(p):
    """Insert three single-character directories, taken from the file name, before it.

    `dir/abcdef.jpg` becomes `dir/a/b/c/abcdef.jpg`. Raises IndexError when the file
    name has fewer than three characters.
    """
    filename = str(p.name)
    first, second, third = filename[0], filename[1], filename[2]
    return p.parent.joinpath(first, second, third, filename)

def image_results(row, n=5):
    """Load the query image of a submission row followed by its first `n` result images.

    `row.id` names the query under 'test/test', and `row.images` is a space-separated
    list of index ids under 'index/index'; both are resolved through `fix_path` with a
    '.jpg' extension and opened with `PILImage.create`.
    """
    query_path = fix_path(Path('test/test') / (row.id + '.jpg'))
    gallery = [PILImage.create(query_path)]
    for image_id in row.images.split(' ')[:n]:
        result_path = fix_path(Path('index/index') / (image_id + '.jpg'))
        gallery.append(PILImage.create(result_path))
    return gallery

def show_all(images, r=1, figsize=(20, 10)):
    """Show `images` in a grid of `r` rows, filled row by row, without axis decorations.

    Args:
        images: sequence of objects accepted by `Axes.imshow`.
        r: number of rows; the column count is the smallest one that fits all images
           (so with r=1 there is one column per image).
        figsize: figure size passed to `plt.subplots`.

    Does nothing for an empty sequence.
    """
    n_images = len(images)
    if n_images == 0:
        return

    c = -(-n_images // r)  # ceil(n_images / r) without importing math
    # squeeze=False always returns a 2-D array of axes; with the default, a 1x1 grid
    # yields a bare Axes object that has no .flatten().
    fig, axs = plt.subplots(r, c, figsize=figsize, squeeze=False)
    axs = axs.flatten()

    # Switch every axis off first so that unused cells of the grid stay blank.
    for ax in axs:
        ax.axis('off')
    for ax, img in zip(axs, images):
        ax.imshow(img)

    plt.tight_layout()
    plt.show()

# Example: show the query image of submission row 10 next to its first results.
show_all(image_results(sub_df.iloc[10]), r=1, figsize=(20, 10))
