In [None]:
import torch

class KMeans:
    """Minimal Lloyd's-algorithm k-means clustering on a 2-D torch tensor.

    Parameters
    ----------
    n_clusters : int
        Number of clusters (k). Must be between 1 and the number of samples.
    max_iter : int, default 100
        Maximum number of assign/update iterations.
    tol : float, default 0.0
        Stop early once no center moves by more than ``tol`` (Euclidean
        distance) within one iteration. With the default of 0.0, iteration
        stops only when the centers are exactly unchanged (a fixed point), so
        the result matches running all ``max_iter`` iterations.
    seed : int or None, default None
        Seed for the random choice of initial centers. ``None`` uses torch's
        global RNG (seed it with ``torch.manual_seed`` for reproducibility).

    Attributes set by ``fit``
    -------------------------
    centers : torch.Tensor, shape (n_clusters, n_features)
        Final cluster centers.
    labels_ : torch.Tensor, shape (n_samples,)
        Index of the nearest final center for each training sample.
    n_iter_ : int
        Number of iterations actually run.
    """

    def __init__(self, n_clusters, max_iter=100, tol=0.0, seed=None):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.seed = seed

    def fit(self, X):
        """Fit the centers to ``X`` of shape (n_samples, n_features).

        Returns
        -------
        KMeans
            ``self``, so calls can be chained (``KMeans(3).fit(X).labels_``).

        Raises
        ------
        ValueError
            If ``X`` is not 2-D or ``n_clusters`` is not in [1, n_samples].
        """
        if X.dim() != 2:
            raise ValueError(
                f"X must be 2-D (n_samples, n_features), got shape {tuple(X.shape)}"
            )
        n_samples = X.shape[0]
        if not 1 <= self.n_clusters <= n_samples:
            raise ValueError(
                f"n_clusters must be between 1 and n_samples={n_samples}, "
                f"got {self.n_clusters}"
            )

        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed)

        # Initialise with k distinct samples chosen at random. clone() makes it
        # explicit that the in-place center updates below never touch X.
        init_idx = torch.randperm(n_samples, generator=generator)[: self.n_clusters]
        self.centers = X[init_idx].clone()

        self.n_iter_ = 0
        for _ in range(self.max_iter):
            # Assignment step: nearest current center for every sample.
            labels = self.predict(X)
            previous = self.centers.clone()

            # Update step: each center becomes the mean of its assigned points.
            for j in range(self.n_clusters):
                mask = labels == j
                # An empty cluster keeps its previous center.
                if mask.any():
                    self.centers[j] = X[mask].mean(dim=0)

            self.n_iter_ += 1
            shift = (self.centers - previous).norm(dim=1).max()
            if shift <= self.tol:
                break

        # The in-loop labels refer to the centers *before* the last update;
        # recompute them against the final centers.
        self.labels_ = self.predict(X)
        return self

    def predict(self, X):
        """Return the index of the nearest fitted center for each row of ``X``."""
        distances = torch.cdist(X, self.centers)
        return torch.argmin(distances, dim=1)


In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

# Path to the image folder. Set the DATA_PATH environment variable to override
# the machine-specific default instead of editing this cell.
data_path = os.environ.get(
    'DATA_PATH',
    '/home/scco0002/F&E_DeepLearning_VS/F-E-Project_Part_2/dataset/mini_testdatensatz/train',
)

# Number of colour clusters (segments) each pixel is assigned to
n_clusters = 3

# Seed torch's global RNG so the random choice of initial centers is reproducible
SEED = 0
torch.manual_seed(SEED)

# Resize every image to the same size so they can be stacked into one tensor
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Only load regular files with an extension PIL recognises; sort so the
# image order (and therefore the seeded result) is deterministic.
image_extensions = {ext.lower() for ext in Image.registered_extensions()}
image_names = sorted(
    name for name in os.listdir(data_path)
    if os.path.isfile(os.path.join(data_path, name))
    and os.path.splitext(name)[1].lower() in image_extensions
)
if not image_names:
    raise FileNotFoundError(f"No image files found in {data_path!r}")

image_list = []
for img_name in image_names:
    img_path = os.path.join(data_path, img_name)
    # `with` closes the file handle. convert('RGB') gives every image 3
    # channels; grayscale/RGBA/palette images would otherwise yield tensors
    # with different channel counts that cannot be stacked.
    with Image.open(img_path) as img:
        image_list.append(transform(img.convert('RGB')))

# (N, C, H, W) float tensor with values in [0, 1]
data_tensor = torch.stack(image_list)
n_images, n_channels, height, width = data_tensor.shape

# Segment by pixel colour: every pixel is one sample with C features. Move the
# channel axis last and flatten to (N*H*W, C). Flattening to (N, C*H*W) would
# cluster whole images and give one label per image, not a per-pixel map.
pixel_array = data_tensor.permute(0, 2, 3, 1).reshape(-1, n_channels)

kmeans = KMeans(n_clusters=n_clusters)
kmeans.fit(pixel_array)

# One cluster label per pixel, reshaped back into (N, H, W) label maps
labels = kmeans.predict(pixel_array).reshape(n_images, height, width)

# Show the first image next to its segmentation
fig, (ax_img, ax_seg) = plt.subplots(1, 2, figsize=(10, 5))
ax_img.imshow(data_tensor[0].permute(1, 2, 0).numpy())
ax_img.set_title(f'Input: {image_names[0]}')
ax_img.axis('off')
ax_seg.imshow(labels[0].numpy())
ax_seg.set_title(f'K-means segmentation (k={n_clusters})')
ax_seg.axis('off')
plt.show()
