# data


In [None]:
import glob
import os
from tqdm import tqdm
import cv2
import numpy as np
import sys

# Add parent directory to Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))


class CFG:
    # ============== comp exp name =============
    current_dir = './'
    segment_path = '../train_scrolls/clustering/'
    
    start_idx = 26
    in_chans = 10
    valid_chans = 10
    
    size = 224
    tile_size = 224
    stride = tile_size // 8

    train_batch_size = 12# 32
    valid_batch_size = 30
    lr = 5e-5
    # ============== model cfg =============
    scheduler = 'cosine'#'cosine', 'linear'
    epochs = 4
    shape = (1000,1000)
    
    # Change the size of fragments
    frags_ratio1 = ['frag','re']
    frags_ratio2 = ['s4','202']
    ratio1 = 2
    ratio2 = 2



def read_image_mask(fragment_id, CFG=None):
    """ 
    Reads a fragment image and its corresponding masks.    
    """
    images = []
    start_idx = CFG.start_idx 
    end_idx = start_idx + CFG.in_chans
    
    idxs = range(start_idx, end_idx)
    
    for i in tqdm(idxs):
        tif_path = os.path.join(CFG.segment_path, fragment_id, "layers", f"{i:02}.tif")
        jpg_path = os.path.join(CFG.segment_path, fragment_id, "layers", f"{i:02}.jpg")
        png_path = os.path.join(CFG.segment_path, fragment_id, "layers", f"{i:02}.png") 
        
        if os.path.exists(tif_path):
            image = cv2.imread(tif_path, 0)
        elif os.path.exists(jpg_path):
            image = cv2.imread(jpg_path, 0)
        else:
            image = cv2.imread(png_path, 0)

        image = cv2.resize(image, (CFG.shape[0],CFG.shape[1]), interpolation=cv2.INTER_AREA)

        image=np.clip(image,0,200)
        images.append(image)

    images = np.stack(images, axis=2)
    print(f" Shape of {fragment_id} segment: {images.shape}")
    
    # Label = first letter of the fragment folder name
    label = fragment_id[0]
    return images, label




In [81]:
import os
from tqdm import tqdm

def load_all_fragments(base_path, CFG):
    """
    Loads all image fragments and their labels from a directory.

    Args:
        base_path (str): Path to the folder containing fragment subfolders.
        CFG (object): Configuration object.

    Returns:
        imgs_list (list of np.ndarray): List of image stacks [H, W, C].
        labels_list (list of str): Corresponding first-letter labels.
    """
    imgs_list = []
    labels_list = []

    for f in tqdm(os.listdir(base_path), desc="Loading fragments"):
        folder_path = os.path.join(base_path, f)
        if os.path.isdir(folder_path):
            try:
                img, label = read_image_mask(f, CFG)
                imgs_list.append(img)
                labels_list.append(label)
            except Exception as e:
                print(f"Skipping {f} due to error: {e}")

    return imgs_list, labels_list


base_path = "../train_scrolls/clustering"
cfg = CFG()  # your configuration object

imgs, labels = load_all_fragments(base_path, cfg)

print(f"Loaded {len(imgs)} fragments.")
print("First labels:", labels[:10])


100%|██████████| 10/10 [00:00<00:00, 582.49it/s], ?it/s]


 Shape of A_1 segment: (1000, 1000, 10)
A


100%|██████████| 10/10 [00:00<00:00, 592.68it/s]


 Shape of R_1 segment: (1000, 1000, 10)
R


100%|██████████| 10/10 [00:00<00:00, 603.51it/s]


 Shape of R_2 segment: (1000, 1000, 10)
R


100%|██████████| 10/10 [00:00<00:00, 590.56it/s]


 Shape of A_2 segment: (1000, 1000, 10)
A


100%|██████████| 10/10 [00:00<00:00, 593.02it/s]
Loading fragments:  83%|████████▎ | 5/6 [00:00<00:00, 43.35it/s]

 Shape of A_3 segment: (1000, 1000, 10)
A


100%|██████████| 10/10 [00:00<00:00, 337.84it/s]
Loading fragments: 100%|██████████| 6/6 [00:00<00:00, 39.44it/s]

 Shape of W_1 segment: (1000, 1000, 10)
W
Loaded 6 fragments.
First labels: ['A', 'R', 'R', 'A', 'A', 'W']





In [80]:
from models.resnetall import generate_model
import torch

fe = generate_model(model_depth=101, n_input_channels=1,forward_features=False,n_classes=1039)
# self.backbone = unetr.MiniUNETR(img_shape=(16, 128, 128),output_dim=1,input_dim=1)
state_dict=torch.load('../checkpoints/r3d101_KM_200ep.pth')["state_dict"]
conv1_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
fe.load_state_dict(state_dict,strict=False)

<All keys matched successfully>

In [None]:
import torch
import numpy as np
from sklearn.cluster import KMeans

# Suppose imgs is a list of numpy arrays [H, W, C] per fragment
# Convert to torch tensor and normalize
imgs_tensor = torch.stack([torch.tensor(im, dtype=torch.float32) for im in imgs])  # [B, H, W, C]
imgs_tensor = imgs_tensor.permute(0, 3, 1, 2).unsqueeze(1)  # [B, 1, C, H, W]
imgs_tensor = imgs_tensor / 255.0  # Normalize to [0,1]

# Send to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fe = fe.to(device)
imgs_tensor = imgs_tensor.to(device)

# Extract features
fe.eval()
with torch.no_grad():
    feats = fe(imgs_tensor)  # Suppose output shape [B, F, D, H, W]


# K-Means clustering
n_clusters = 3
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
cluster_labels = kmeans.fit_predict(feats.cpu())

# Output results
for i, frag_label in enumerate(cluster_labels):
    print(f"Fragment {i} (label {labels[i]}): Cluster {frag_label}")


Fragment 0 (label A): Cluster 0
Fragment 1 (label R): Cluster 0
Fragment 2 (label R): Cluster 2
Fragment 3 (label A): Cluster 0
Fragment 4 (label A): Cluster 1
Fragment 5 (label W): Cluster 0


: 

In [73]:
imgs = torch.tensor(imgs)
imgs = imgs.to(torch.float32)
imgs.shape


torch.Size([6, 1000, 1000, 10])

In [None]:
import torch
from sklearn.cluster import KMeans

# (1) Prepare data
imgs = imgs.permute(0, 3, 1, 2).unsqueeze(1)  # [B, C, D, H, W]

# (2) Extract features
fe.eval()
with torch.no_grad():
    feats = fe(imgs)
    
# (4) Cluster
kmeans = KMeans(n_clusters=3)
labels = kmeans.fit_predict(feats)

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 5

In [75]:
feats.shape

torch.Size([6, 1039])

In [None]:
# (4) Cluster
kmeans = KMeans(n_clusters=3)
labels = kmeans.fit_predict(feats)

# # (5) Reshape back
# labels_3d = labels.reshape(feats.shape[0], feats.shape[2], feats.shape[3], feats.shape[4])