# In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from sklearn.metrics import pairwise_distances
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import torch
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from lightly import loss, transforms
from lightly.data import LightlyDataset
from lightly.models.modules import heads
import ipywidgets as widgets
from IPython.display import display

# Define SimCLR model
class SimCLR(torch.nn.Module):
    """Backbone encoder followed by a SimCLR projection head.

    Args:
        backbone: Module applied to the input batch; its output is flattened
            from dimension 1 onwards before being projected.
        input_dim: Size of the flattened backbone output fed to the
            projection head. The default of 512 is the value this file uses
            with its ResNet-18 backbone; adjust it for other backbones.
        hidden_dim: Hidden size passed to the projection head.
        output_dim: Size of the embedding returned by ``forward``.
    """

    def __init__(self, backbone, input_dim=512, hidden_dim=512, output_dim=128):
        super().__init__()
        self.backbone = backbone
        self.projection_head = heads.SimCLRProjectionHead(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
        )

    def forward(self, x):
        """Return the projection-head output for the batch ``x``."""
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model and move to GPU if available.
# resnet18() builds a randomly initialised network by default; the
# `pretrained=` keyword is deprecated in recent torchvision releases, so it is
# simply not passed.
backbone = torchvision.models.resnet18()
# Replace the classification layer so the backbone returns its pooled features.
backbone.fc = torch.nn.Identity()
model = SimCLR(backbone).to(device)

# Prepare dataset and dataloader
transform = transforms.SimCLRTransform(input_size=32)
dataset = LightlyDataset(input_dir="./data/cifar-10/train/", transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Define loss and optimizer
criterion = loss.NTXentLoss(temperature=0.5).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-6)

# Training loop
# Explicitly enter train mode: the helpers defined further down in this file
# call model.eval(), so re-running this loop would otherwise train in eval mode.
model.train()
for epoch in range(10):
    for (view0, view1), _, _ in dataloader:
        view0, view1 = view0.to(device), view1.to(device)
        z0 = model(view0)
        z1 = model(view1)
        # Named `batch_loss` so the imported `loss` module (used to build
        # `criterion`) is not overwritten by a tensor.
        batch_loss = criterion(z0, z1)
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f"Loss: {batch_loss.item():.5f}")

# Function to extract features
def extract_features(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and stack the outputs row-wise.

    Args:
        model: Module called on each input batch. It is switched to eval mode
            and left in that mode.
        dataloader: Iterable of batches. A batch is either the input tensor
            itself or a list/tuple whose first element holds the inputs. When
            that first element is itself a list/tuple of views (as yielded by
            the training dataloader in this file), only the first view is
            embedded.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        A NumPy array with one row per input, in the order the dataloader
        yielded them.

    Raises:
        ValueError: If ``dataloader`` yields no batches (from ``np.vstack``).
    """
    model.eval()
    features = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            if isinstance(inputs, (list, tuple)):
                # A list of views has no .to(); embed the first view only.
                inputs = inputs[0]
            outputs = model(inputs.to(device))
            features.append(outputs.cpu().numpy())
    return np.vstack(features)

# Extract features
# The training `dataloader` shuffles and yields random augmented views, so
# embeddings taken from it cannot be matched back to dataset positions. Use a
# deterministic, unshuffled pass so that row i of `features` belongs to sample i.
eval_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    # ImageNet statistics; assumed to match SimCLRTransform's default
    # normalisation -- TODO confirm against the installed lightly version.
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
eval_dataset = LightlyDataset(input_dir="./data/cifar-10/train/", transform=eval_transform)
eval_dataloader = DataLoader(eval_dataset, batch_size=128, shuffle=False)
features = extract_features(model, eval_dataloader, device)

# Visualize selected images
def visualize_selected_images(images, labels, selected_indices, image_size=(32, 32), grid_shape=(2, 5)):
    """Plot the selected images on a grid, each titled with its label.

    Args:
        images: Indexable collection; ``images[idx]`` must support
            ``reshape(height, width, 3)``.
        labels: Indexable collection of labels aligned with ``images``.
        selected_indices: Sized sequence of indices to display. Only the first
            ``grid_shape[0] * grid_shape[1]`` entries fit on the grid; the
            rest are not drawn.
        image_size: ``(height, width)`` each image is reshaped to.
        grid_shape: ``(rows, columns)`` of the subplot grid.
    """
    # squeeze=False keeps `axs` a 2-D array for every grid shape; without it a
    # 1x1 grid returns a bare Axes object and flatten() below would fail.
    fig, axs = plt.subplots(
        nrows=grid_shape[0], ncols=grid_shape[1], figsize=(15, 6), squeeze=False
    )
    axs = axs.flatten()
    for ax, idx in zip(axs, selected_indices):
        image = images[idx].reshape(image_size[0], image_size[1], 3)
        ax.imshow(image)
        ax.set_title(f"Label: {labels[idx]}")
        ax.axis('off')
    # Hide the grid cells that received no image.
    for ax in axs[len(selected_indices):]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# k-Center Greedy algorithm for diversity (coverage) sampling
def k_center_greedy(features, n_samples):
    """Greedily pick points that are far apart in feature space.

    The first center is drawn at random; each further center is the point
    whose distance to its nearest already-chosen center is largest.

    Args:
        features: 2-D array with one row per point.
        n_samples: Number of centers requested. Values above the number of
            points are clamped, since extra picks could only repeat points.

    Returns:
        A list of row indices (plain ints). Empty when ``n_samples <= 0`` or
        ``features`` has no rows.
    """
    n_points = features.shape[0]
    n_samples = min(n_samples, n_points)
    if n_samples <= 0:
        return []

    centers = [int(np.random.randint(n_points))]
    # Distance from every point to its nearest chosen center so far.
    min_distances = np.full(n_points, np.inf)
    while len(centers) < n_samples:
        latest_center = features[centers[-1]].reshape(1, -1)
        to_latest = pairwise_distances(features, latest_center).flatten()
        min_distances = np.minimum(min_distances, to_latest)
        centers.append(int(np.argmax(min_distances)))
    return centers

# Function to calculate uncertainty (using model predictions)
def calculate_uncertainty(model, dataloader, device):
    """Score each input by the negated maximum softmax value of the model output.

    A higher (less negative) score means the largest softmax value was
    smaller.

    Args:
        model: Module called on each input batch; its output is passed through
            a softmax over dimension 1. It is switched to eval mode and left
            in that mode.
        dataloader: Iterable of batches. A batch is either the input tensor
            itself or a list/tuple whose first element holds the inputs; any
            further elements (labels, file names, ...) are ignored. When the
            first element is itself a list/tuple of views, only the first
            view is scored.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        A 1-D NumPy array with one score per input, in the order the
        dataloader yielded them; an empty array if there were no batches.
    """
    model.eval()
    per_batch = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            if isinstance(inputs, (list, tuple)):
                # A list of views has no .to(); score the first view only.
                inputs = inputs[0]
            outputs = model(inputs.to(device))
            probs = torch.nn.functional.softmax(outputs, dim=1)
            uncertainty = -torch.max(probs, dim=1)[0]  # Use negative max probability as uncertainty
            per_batch.append(uncertainty.cpu().numpy())
    if not per_batch:
        return np.array([])
    return np.concatenate(per_batch)

# Function to apply clustering and select the most uncertain samples per cluster
def select_uncertain_samples(features, uncertainties, n_clusters, n_samples_per_cluster):
    """Cluster the features with k-means and keep the most uncertain points of each cluster.

    Args:
        features: 2-D array with one row per sample; input to KMeans.
        uncertainties: One score per sample, aligned with ``features``; larger
            values are selected first.
        n_clusters: Number of k-means clusters.
        n_samples_per_cluster: Maximum number of samples taken from each
            cluster; smaller clusters contribute all their members.

    Returns:
        A list of sample indices (plain ints), grouped by cluster and ordered
        from most to least uncertain within each cluster.
    """
    uncertainties = np.asarray(uncertainties)
    cluster_labels = KMeans(n_clusters=n_clusters, random_state=0).fit(features).labels_
    selected_indices = []
    for cluster in range(n_clusters):
        cluster_indices = np.where(cluster_labels == cluster)[0]
        # Negate so argsort orders this cluster's members by descending uncertainty.
        ranking = np.argsort(-uncertainties[cluster_indices])[:n_samples_per_cluster]
        selected_indices.extend(cluster_indices[ranking].tolist())
    return selected_indices

# Example usage with synthetic stand-in data.
# A real run is expected to supply X_train, y_train, and a trained model.
# NOTE: this rebinds `features`, replacing the embeddings extracted earlier in
# this file with random values.
features = np.random.rand(1000, 64)  # Placeholder feature matrix
uncertainties = np.random.rand(1000)  # Placeholder scores; replace with actual uncertainties
n_clusters, n_samples_per_cluster = 10, 5

selected_indices = select_uncertain_samples(
    features=features,
    uncertainties=uncertainties,
    n_clusters=n_clusters,
    n_samples_per_cluster=n_samples_per_cluster,
)

# Interactive Labeling
def interactive_labeling(images, labels, selected_indices):
    """Show one label dropdown per selected sample and write the choices back.

    Args:
        images: Not used by this function.
        labels: Mutable, indexable collection of hashable labels. Its distinct
            values are the dropdown options, and it is updated in place when
            the button is clicked.
        selected_indices: Indices into ``labels`` to present for relabeling.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # dropdown options appear in a stable order (a set has no defined order).
    options = list(dict.fromkeys(labels))
    label_widgets = [
        widgets.Dropdown(options=options, description=f'Label {i+1}', value=labels[idx])
        for i, idx in enumerate(selected_indices)
    ]
    button = widgets.Button(description="Submit Labels")
    output = widgets.Output()

    def on_button_clicked(b):
        with output:
            for idx, widget in zip(selected_indices, label_widgets):
                labels[idx] = widget.value
            # Built inline: no `extract_labels_at_indices` helper is defined
            # in the visible part of this file.
            print("Updated labels:", [labels[idx] for idx in selected_indices])

    button.on_click(on_button_clicked)
    display(*label_widgets, button, output)

# Visualize and label
# Size the grid to the number of selected samples so that every sample that
# gets a dropdown below is also shown; the default 2x5 grid would draw only the
# first 10.
grid_cols = 5
grid_rows = max(1, -(-len(selected_indices) // grid_cols))  # ceiling division
visualize_selected_images(X_train, y_train, selected_indices, grid_shape=(grid_rows, grid_cols))
interactive_labeling(X_train, y_train, selected_indices)
