In [1]:
import glob
import pickle
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from torchvision import transforms

import pandas as pd
from ekfplot import plot as ek, colors as ec
from ekfstats import math, fit, imstats

from pieridae.starbursts import sample

In [2]:
import torch
import torch.nn as nn
from byol_pytorch import BYOL
from torchvision import models

In [3]:
from tqdm import tqdm
from ekfstats import functions

In [4]:
# Collect the i-band results pickle found in each M* directory; the matching
# g-band path is derived from it below.  glob returns paths in arbitrary,
# filesystem-dependent order, so sort them: positional indices into `imgs` /
# `img_names` (e.g. a hard-coded `sidx`) then point at the same entry on every
# run and on every machine.
filenames = sorted(
    glob.glob('../local_data/pieridae_output/starlet/starbursts_v0/M*/*i_results.pkl')
)

imgs = []
img_names = []
for fname in filenames:
    # Swap the band tag in the file name only, so a directory component that
    # happens to contain '_i_' is never rewritten.
    directory, sep, basename = fname.rpartition('/')

    img = []  # channel order: g 'image', i 'image', i 'hf_image'
    for band in 'gi':
        current_filename = directory + sep + basename.replace('_i_', f'_{band}_')

        # NOTE(review): pickle.load can execute arbitrary code contained in the
        # file; only point this at pipeline outputs you trust.
        with open(current_filename, 'rb') as f:
            xf = pickle.load(f)

        img.append(xf['image'])
        if band == 'i':
            img.append(xf['hf_image'])

    imgs.append(np.array(img))
    # The parent directory name (the M* folder) is used as the image's name.
    img_names.append(fname.split('/')[-2])
imgs = np.array(imgs)
img_names = np.array(img_names)

In [5]:
def sample_unlabelled_images():
    """Return every image in the module-level ``imgs`` array, in random order.

    A fresh permutation of the first axis is drawn from NumPy's global random
    state on each call, and the reordered array is converted to a ``float32``
    torch tensor.
    """
    shuffled_order = np.random.permutation(len(imgs))
    shuffled = imgs[shuffled_order]
    return torch.tensor(shuffled, dtype=torch.float32)

In [6]:
# Two augmentation pipelines, passed to BYOL as `augment_fn` / `augment_fn2`.
# Each is a chain of torchvision transforms wrapped in nn.Sequential.
transform1 = nn.Sequential(*[
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=180),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
])

transform2 = nn.Sequential(*[
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=180),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    # The blur in this pipeline is only applied 30% of the time.
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
])

# Backbone network handed to BYOL, initialised with pretrained weights.
resnet = models.resnet18(pretrained=True)

byol_settings = dict(
    image_size=150,
    hidden_layer='avgpool',       # layer whose output is taken as the embedding
    projection_size=256,          # final projection dimension
    projection_hidden_size=1024,  # hidden layer width in the projector MLP
    moving_average_decay=0.99,    # τ_base, chosen for a short training run
    use_momentum=True,
    augment_fn=transform1,
    augment_fn2=transform2,
)
learner = BYOL(resnet, **byol_settings)



In [None]:
# Optimise the BYOL learner with Adam for a fixed number of steps.
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

for _ in tqdm(range(20)):
    # Clear gradients left over from the previous step before the new pass.
    opt.zero_grad()

    # Every step sees all images, in a freshly shuffled order.
    images = sample_unlabelled_images()
    loss = learner(images)

    loss.backward()
    opt.step()

    # Update the moving average of the target encoder.
    learner.update_moving_average()

  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# Compute projections and embeddings for all images in inference mode.
# eval() switches layers such as BatchNorm to their inference behaviour, so the
# embeddings do not depend on batch statistics; no_grad() avoids building an
# autograd graph for the whole dataset.  This matches how
# create_embeddings_umap() extracts embeddings.
# The learner is left in eval mode: call learner.train() before training again.
learner.eval()
with torch.no_grad():
    projection, embedding = learner(torch.tensor(imgs, dtype=torch.float32), return_embedding = True)

In [None]:
from astropy.visualization import make_lupton_rgb

In [None]:
# Example: find the most similar image to a query image using the embeddings.
from sklearn.metrics.pairwise import cosine_similarity

# Embeddings for all images as a NumPy array (n_images, embedding_dim).
embeddings = embedding.detach().cpu().numpy()

# Pairwise cosine similarity between every pair of images.
similarity_matrix = cosine_similarity(embeddings)

# Query image, selected by position.  To select by name instead:
#   sidx = np.where(img_names == 'M3406229848245433130')[0][0]
sidx = 301

# Mask out the query itself before taking the maximum.  Relying on the query
# being ranked first (argsort()[-2]) can return the query when another image
# ties with it or exceeds it through floating-point rounding.
query_similarities = similarity_matrix[sidx].copy()
query_similarities[sidx] = -np.inf
most_similar_idx = query_similarities.argmax()
print(f"Most similar image to image {sidx}: image {most_similar_idx}")

In [None]:
# Show channel 1 of the query image (left) next to its closest match (right).
fig, axarr = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
left_panel, right_panel = axarr
ek.imshow(imgs[sidx, 1], ax=left_panel, q=0.01)
ek.imshow(imgs[most_similar_idx, 1], ax=right_panel, q=0.01)

In [None]:
import umap

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import umap

def create_embeddings_umap(learner, images, n_components=2,
                          n_neighbors=15, min_dist=0.1, metric='cosine',
                          random_state=42):
    """
    Embed images with a BYOL learner and project the embeddings with UMAP.

    The learner is put into eval mode and called without gradient tracking
    using ``return_embedding=True``; the second element of the returned pair
    is taken as the embedding. Non-finite values in the embeddings are
    replaced by 0 before the UMAP fit. Progress messages are printed.

    Args:
        learner: Trained BYOL model. It is left in eval mode on return.
        images: Image tensor to pass to the learner.
        n_components: Dimensionality of the UMAP output.
        n_neighbors: UMAP neighbourhood size. It is capped at
            ``n_samples - 1`` so it never exceeds the available samples.
        min_dist: UMAP minimum distance between points in the output space.
        metric: Distance metric passed to UMAP.
        random_state: Seed passed to UMAP.

    Returns:
        tuple: ``(reducer, embedding_reduced)`` where ``reducer`` is the
        fitted ``umap.UMAP`` object and ``embedding_reduced`` is the result
        of ``fit_transform`` on the embeddings.

    Example:
        >>> reducer, embedding_2d = create_embeddings_umap(learner, imgs_tensor)
        >>> print(f"Reduced embeddings shape: {embedding_2d.shape}")
    """
    # Forward pass in inference mode; only the embedding output is kept.
    learner.eval()
    with torch.no_grad():
        model_output = learner(images, return_embedding=True)
    _, embeddings = model_output

    # float32 NumPy copy with NaN / +inf / -inf all mapped to 0.
    embeddings_np = np.nan_to_num(
        embeddings.cpu().numpy().astype(np.float32),
        nan=0.0,
        posinf=0.0,
        neginf=0.0,
    )

    umap_settings = dict(
        n_components=n_components,
        # UMAP needs fewer neighbours than samples.
        n_neighbors=min(n_neighbors, len(embeddings_np) - 1),
        min_dist=min_dist,
        metric=metric,
        random_state=random_state,
        verbose=False,
    )

    print(f"Fitting UMAP with {len(embeddings_np)} samples...")
    print(f"Original embeddings shape: {embeddings_np.shape}")

    reducer = umap.UMAP(**umap_settings)
    embedding_reduced = reducer.fit_transform(embeddings_np)

    print(f"UMAP embedding shape: {embedding_reduced.shape}")

    return reducer, embedding_reduced


def visualize_embeddings(embedding_reduced, labels=None, ax=None, figsize=(10, 8), 
                        title=None, point_size=50, alpha=0.7, colormap='tab10',
                        save_path=None):
    """
    Create a scatter plot visualization of reduced embeddings.

    Args:
        embedding_reduced: Numpy array of reduced embeddings, shape
            (n_samples, n_dims). Exactly 3 columns produces a 3D plot;
            otherwise the first two columns are plotted in 2D.
        labels: Optional array of labels for coloring points, shape
            (n_samples,). When given, a colorbar is added; when None all
            points are drawn in grey.
        ax: Existing matplotlib axis to draw on. If None, a new figure and
            axis are created (with a 3D projection for 3D embeddings).
        figsize: Figure size (width, height) in inches. Only used when a new
            figure is created, i.e. when ``ax`` is None.
        title: Title for the plot. If None, no title is set.
        point_size: Size of scatter plot points.
        alpha: Transparency of points (0.0 to 1.0).
        colormap: Matplotlib colormap used when ``labels`` are provided.
        save_path: Optional path to save the figure that contains ``ax``
            (e.g., 'plot.png', 'plot.pdf').

    Returns:
        The matplotlib axis that was drawn on (``ax`` itself when one was
        passed in). The owning figure is available as ``ax.figure``.

    Example:
        >>> ax = visualize_embeddings(embedding_2d, labels=cluster_labels)
        >>> plt.show()

        >>> # For 3D visualization
        >>> ax = visualize_embeddings(embedding_3d, labels=labels)
        >>> plt.show()
    """
    # Determine if this is 2D or 3D
    is_3d = embedding_reduced.shape[1] == 3

    # Create a figure only when the caller did not supply an axis
    if ax is None:
        if is_3d:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111, projection='3d')
        else:
            fig, ax = plt.subplots(figsize=figsize)

    # One coordinate array per plotted dimension
    coords = [embedding_reduced[:, dim] for dim in range(3 if is_3d else 2)]

    if labels is not None:
        scatter = ax.scatter(*coords, c=labels, cmap=colormap,
                             alpha=alpha, s=point_size)
        # Attach the colorbar to the figure that owns `ax`, which is not
        # necessarily pyplot's current figure when `ax` was passed in.
        ax.figure.colorbar(scatter, ax=ax, shrink=0.8 if is_3d else 1.0)
    else:
        ax.scatter(*coords, alpha=alpha, s=point_size, c='grey')

    # Only set a title when the caller asked for one
    if title is not None:
        ax.set_title(title)

    # Set axis labels
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    if is_3d:
        ax.set_zlabel('UMAP 3')
    else:
        ax.grid(True, alpha=0.3)

    # Save the figure containing `ax`, not whichever figure is current
    if save_path:
        ax.figure.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")

    return ax

In [None]:
# Pass every image through the learner and reduce the resulting embeddings
# to two dimensions with UMAP (cosine metric, fixed seed).
reducer, embedding_2d = create_embeddings_umap(
    learner,
    torch.tensor(imgs, dtype=torch.float32),
    n_components=2,
    n_neighbors=15,
    min_dist=0.01,
    metric='cosine',
    random_state=42,
)


In [None]:
# Load the classification table; the CSV's first column becomes the index
# (the table is later re-indexed by `img_names`).
mergers = pd.read_csv('./classifications_kadofong_20250925.csv', index_col=0)
# Integer class codes used in the table:
# 1 undisturbed
# 2 ambiguous
# 3 merger
# 4 fragmentation
# 5 artifact

In [None]:
# Align the classifications with the image order.  Images that have no entry
# in the table come back as NaN and are given the code 0.
labels = (
    mergers
    .reindex(img_names)
    .replace(np.nan, 0)
    .values
    .flatten()
)

# Discrete colormap used for the class codes.
cmap_1 = ec.colormap_from_list(
    ['C0', 'tab:orange', 'r', 'tab:green', 'C4', 'k'],
    'discrete',
)

# Human-readable name for each class code, starting at code 1.
names = dict(enumerate(
    ['undisturbed', 'ambiguous', 'merger', 'fragmentation', 'artifact'],
    start=1,
))

In [None]:
from matplotlib import gridspec

# Layout: the UMAP scatter fills the left half (both rows); the right half is
# a 2x2 grid of example-image panels, filled row by row.
fig = plt.figure(figsize=(12,4))
ovlgrid = gridspec.GridSpec(2, 4)
ax1 = fig.add_subplot(ovlgrid[:,:2])
ax_clusters = [fig.add_subplot(ovlgrid[row, col]) for row in (0, 1) for col in (2, 3)]

# All images as faint grey points (no labels passed).
visualize_embeddings(
    embedding_reduced=embedding_2d,
    point_size=10,
    alpha=0.1,
    ax=ax1,
)

# Overplot the images that have a classification (code > 0), colored by class.
ek.scatter(
    embedding_2d[labels>0,0],
    embedding_2d[labels>0,1],
    c=labels[labels>0],
    cmap=cmap_1,
    vmin=1,
    vmax=6,
    ax=ax1,
    s=6
)

# Which member of each class to show as its example (0-based position among
# the images carrying that label).  Kept under its own name so the query index
# `sidx` used by the similarity cells is not overwritten by running this cell.
example_idx = 1
for ix,lbl in enumerate([1,2,3,4]):
    # Positions of the images with this label; indexing through them avoids
    # copying the whole image subset just to read one entry.
    members = np.flatnonzero(labels==lbl)

    # Leave the panel empty when the class has too few members.
    if example_idx >= len(members):
        continue
    pick = members[example_idx]

    ax = ax_clusters[ix]
    ek.imshow( imgs[pick][2], ax=ax )
    ek.text(0.025,0.975, names[lbl], color=cmap_1((lbl-1.)/5.), bordercolor='w', borderwidth=3, ax=ax)

    # Mark the example's position on the UMAP panel with a black-edged point.
    ax1.scatter(
        embedding_2d[pick,0],
        embedding_2d[pick,1],
        c=lbl,
        cmap=cmap_1,
        edgecolor='k',
        vmin=1,
        vmax=6,
    )
    

In [None]:
# Debug check: display the value `lbl` was left holding by the preceding loop.
lbl