## Making Model Prediction masks

- Get the prediction mask for a specific city tile using the `.pt` model file and the `model class`
- Load the specific city from the dataloader

### Data Loader

In [None]:
import os
import rasterio
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ChangeDetectionDatasetTIF(Dataset):
    """Bi-temporal change-detection dataset backed by ``.tif`` tiles.

    Each item is ``(img_t2019, img_t2024, cd_mask)``:
    the two images are float32 tensors with every band divided by 255
    (presumably 8-bit imagery - confirm for other sources), and the mask is the
    first band of the mask file as an int64 tensor.

    Files are paired purely by position after sorting each directory's ``.tif``
    names, so the three directories must hold the same number of tiles in the
    same sorted order.
    """

    def __init__(self, t2019_dir, t2024_dir, mask_dir, classes, transform=None):
        """
        Args:
            t2019_dir: Directory of 2019 image tiles (.tif).
            t2024_dir: Directory of 2024 image tiles (.tif).
            mask_dir: Directory of change-mask tiles (.tif).
            classes: Change-detection class names; only stored on the dataset.
            transform: Optional callable applied to each image tensor (not to the mask).

        Raises:
            ValueError: If the three directories contain different numbers of
                .tif files, since positional pairing would then misalign samples.
        """
        self.t2019_dir = t2019_dir
        self.t2024_dir = t2024_dir
        self.mask_dir = mask_dir
        self.classes = classes  # Change detection classes
        self.transform = transform

        # Load all paths (sorted, so positional pairing is deterministic)
        self.t2019_paths = self._list_tifs(t2019_dir)
        self.t2024_paths = self._list_tifs(t2024_dir)
        self.mask_paths = self._list_tifs(mask_dir)

        # Fail fast: __len__ follows the 2019 listing, so a shorter/longer
        # 2024 or mask listing would otherwise pair the wrong files silently.
        counts = (len(self.t2019_paths), len(self.t2024_paths), len(self.mask_paths))
        if len(set(counts)) != 1:
            raise ValueError(
                "Mismatched tile counts: "
                f"{t2019_dir}={counts[0]}, {t2024_dir}={counts[1]}, {mask_dir}={counts[2]}"
            )

    @staticmethod
    def _list_tifs(directory):
        """Return the sorted names of the ``.tif`` files directly inside *directory*."""
        return sorted(f for f in os.listdir(directory) if f.endswith('.tif'))

    def __len__(self):
        return len(self.t2019_paths)

    def __getitem__(self, index):
        """Return ``(img_t2019, img_t2024, cd_mask)`` for the tile at *index*."""
        # Load images using rasterio
        with rasterio.open(os.path.join(self.t2019_dir, self.t2019_paths[index])) as src:
            img_t2019 = src.read(out_dtype=np.float32) / 255.0
        with rasterio.open(os.path.join(self.t2024_dir, self.t2024_paths[index])) as src:
            img_t2024 = src.read(out_dtype=np.float32) / 255.0
        # Load masks (first band only)
        with rasterio.open(os.path.join(self.mask_dir, self.mask_paths[index])) as src:
            cd_mask = src.read(1).astype(np.int64)

        # Convert to PyTorch tensors
        img_t2019 = torch.from_numpy(img_t2019)
        img_t2024 = torch.from_numpy(img_t2024)
        cd_mask = torch.from_numpy(cd_mask)

        # Apply transforms if any (images only; the mask is left untouched)
        if self.transform is not None:
            img_t2019 = self.transform(img_t2019)
            img_t2024 = self.transform(img_t2024)

        return img_t2019, img_t2024, cd_mask

def describe_loader(loader_type):
    """Print batch/shape/class information taken from the first batch of *loader_type*.

    The loader's dataset must expose a ``classes`` attribute and yield
    ``(image_2019, image_2024, change_mask)`` triples.
    """
    before, after, change = next(iter(loader_type))
    summary = [
        ("Batch size:", loader_type.batch_size),
        ("2019 Image Shape:", before.shape),
        ("2024 Image Shape:", after.shape),
        ("Change Mask Shape:", change.shape),
        ("Number of images:", len(loader_type.dataset)),
        ("Classes:", loader_type.dataset.classes),
        ("Unique CD values:", torch.unique(change)),
    ]
    for label, value in summary:
        print(label, value)

# Example usage:
ROOT_DIRECTORY = "ChangeDetectionMergedDividedSplit-tif3"
SAVING_DIR = "image_saver"
CD_DIR = "cd2_Output"
#CLASSES = ['no_change','vegetation_increase','vegetation_decrease']
CLASSES = [
    'no_change',
    'water_building', 'water_sparse', 'water_dense',
    'building_water', 'building_sparse', 'building_dense',
    'sparse_water', 'sparse_building', 'sparse_dense',
    'dense_water', 'dense_building', 'dense_sparse',
]


def _split_dataset(split):
    """Build the dataset for one split folder ('train', 'val' or 'test') under ROOT_DIRECTORY."""
    base = f"{ROOT_DIRECTORY}/{split}"
    return ChangeDetectionDatasetTIF(
        t2019_dir=f"{base}/Images/T2019",
        t2024_dir=f"{base}/Images/T2024",
        mask_dir=f"{base}/{CD_DIR}",
        classes=CLASSES,
    )


# Create datasets
train_dataset = _split_dataset("train")
val_dataset = _split_dataset("val")
test_dataset = _split_dataset("test")

# Create dataloaders
### KEEP SHUFFLE=FALSE (to get same sample index each time)
num_workers = 8  # only used if num_workers/pin_memory are re-enabled below
batch_size = 32
loader_kwargs = dict(batch_size=batch_size, shuffle=False)  # , num_workers=num_workers, pin_memory=True
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

for split_header, split_loader in (
    ("------------Train-----------", train_loader),
    ("------------Val------------", val_loader),
    ("------------Test------------", test_loader),
):
    print(split_header)
    describe_loader(split_loader)

### Model Definition

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

class Strategy3Model:
    """Combined binary change-detection (CD) and land-cover mapping (LCM) model.

    Holds two independent smp segmentation networks:

    * ``cd_model`` takes both epochs concatenated on the channel axis
      (``2 * input_channels`` channels) and outputs a single change channel.
    * ``lcm_model`` takes one epoch (``input_channels`` channels) and outputs
      ``num_semantic_classes`` land-cover channels; it is applied to each year
      separately.

    This wrapper is not an ``nn.Module``; ``to``/``train``/``eval`` are forwarded
    to both networks.
    """

    def __init__(self, cd_architecture='unet', lcm_architecture='unet',
                 cd_encoder='resnet34', lcm_encoder='resnet34',
                 input_channels=3, num_classes=13, num_semantic_classes=4):
        """
        Args:
            cd_architecture: 'unet' or 'deeplabv3plus' for the CD network.
            lcm_architecture: 'unet' or 'deeplabv3plus' for the LCM network.
            cd_encoder: smp encoder name for the CD network (ImageNet weights).
            lcm_encoder: smp encoder name for the LCM network (ImageNet weights).
            input_channels: Bands per single-epoch image.
            num_classes: Number of semantic change classes. Not used by this
                class; kept for interface compatibility.
            num_semantic_classes: Land-cover classes predicted by the LCM network.

        Raises:
            ValueError: If an architecture name is not supported.
        """
        # Initialize CD model
        self.cd_model = self._create_cd_model(
            architecture=cd_architecture,
            encoder=cd_encoder,
            input_channels=input_channels
        )
        # Initialize LCM model
        self.lcm_model = self._create_lcm_model(
            architecture=lcm_architecture,
            encoder=lcm_encoder,
            input_channels=input_channels,
            num_semantic_classes=num_semantic_classes
        )

    def _create_cd_model(self, architecture, encoder, input_channels):
        """Create the binary change detection model (1 output channel, no activation).

        Raises:
            ValueError: If *architecture* is not 'unet' or 'deeplabv3plus'.
        """
        arch = architecture.lower()
        if arch == 'unet':
            return smp.Unet(
                encoder_name=encoder,
                encoder_weights='imagenet',
                in_channels=input_channels * 2,  # Concatenated images
                classes=1,  # Binary output
                encoder_depth=4,  # Reduce depth (def=5)
                decoder_channels=(256, 128, 64, 32),  # Reduce channels (def=(256, 128, 64, 32, 16))
            )
        if arch == 'deeplabv3plus':
            return smp.DeepLabV3Plus(
                encoder_name=encoder,
                encoder_weights='imagenet',
                in_channels=input_channels * 2,
                classes=1,
            )
        # Add more architectures as needed
        raise ValueError(
            f"Unsupported CD architecture {architecture!r}; expected 'unet' or 'deeplabv3plus'"
        )

    def _create_lcm_model(self, architecture, encoder, input_channels, num_semantic_classes=4):
        """Create the land cover mapping model (*num_semantic_classes* output channels).

        Raises:
            ValueError: If *architecture* is not 'unet' or 'deeplabv3plus'.
        """
        arch = architecture.lower()
        if arch == 'unet':
            return smp.Unet(
                encoder_name=encoder,
                encoder_weights='imagenet',
                in_channels=input_channels,
                classes=num_semantic_classes,  # 4 land cover classes
            )
        if arch == 'deeplabv3plus':
            return smp.DeepLabV3Plus(
                encoder_name=encoder,
                encoder_weights='imagenet',
                in_channels=input_channels,
                classes=num_semantic_classes,
            )
        # Add more architectures as needed
        raise ValueError(
            f"Unsupported LCM architecture {architecture!r}; expected 'unet' or 'deeplabv3plus'"
        )

    def to(self, device):
        """Move both models to *device* and return self (allows chaining)."""
        self.cd_model = self.cd_model.to(device)
        self.lcm_model = self.lcm_model.to(device)
        return self

    def train(self):
        """Set both models to training mode and return self."""
        self.cd_model.train()
        self.lcm_model.train()
        return self

    def eval(self):
        """Set both models to evaluation mode and return self."""
        self.cd_model.eval()
        self.lcm_model.eval()
        return self

def create_semantic_change_mask(binary_pred, lcm_pred_2019, lcm_pred_2024, threshold=0.5):
    """Convert binary change + LCM predictions to a 13-class semantic change mask.

    A pixel counts as changed when ``binary_pred[:, 0] > threshold``. A changed
    pixel gets the class of its (2019 land cover -> 2024 land cover) transition
    from the table below. Unchanged pixels, and changed pixels whose land cover is
    identical in both years, get 0 (no change).

    Args:
        binary_pred: Binary change prediction tensor (B, 1, H, W); only channel 0 is used.
        lcm_pred_2019: Land-cover scores for 2019 (B, C, H, W); argmax over C is the class.
        lcm_pred_2024: Land-cover scores for 2024 (B, C, H, W).
            C is expected to be 4: larger class ids would index past the 16-entry table.
        threshold: Value ``binary_pred`` must exceed for a pixel to be "changed".
            NOTE(review): the CD networks in this notebook are built without an
            ``activation`` argument, so ``binary_pred`` is presumably a raw logit, and
            0.5 then corresponds to a probability of about 0.62. Confirm against training.

    Returns:
        Semantic change mask tensor (B, H, W), dtype long, values 0-12.
    """
    device = binary_pred.device
    num_classes = 4  # Water, Building, Sparse, Dense

    # Land-cover class per pixel for each year
    lcm_2019 = torch.argmax(lcm_pred_2019, dim=1)  # (B, H, W)
    lcm_2024 = torch.argmax(lcm_pred_2024, dim=1)  # (B, H, W)

    # Changed-pixel mask
    change_mask = binary_pred[:, 0] > threshold  # (B, H, W)

    # (from_class, to_class) -> semantic class; unlisted pairs (same class) stay 0.
    transition_map = {
        (0, 1): 1,   # Water → Building
        (0, 2): 2,   # Water → Sparse
        (0, 3): 3,   # Water → Dense
        (1, 0): 4,   # Building → Water
        (1, 2): 5,   # Building → Sparse
        (1, 3): 6,   # Building → Dense
        (2, 0): 7,   # Sparse → Water
        (2, 1): 8,   # Sparse → Building
        (2, 3): 9,   # Sparse → Dense
        (3, 0): 10,  # Dense → Water
        (3, 1): 11,  # Dense → Building
        (3, 2): 12,  # Dense → Sparse
    }

    # Flat lookup table indexed by from_class * num_classes + to_class. Built on the
    # host and transferred once, with an explicit integer dtype.
    table = [0] * (num_classes * num_classes)
    for (from_idx, to_idx), semantic_idx in transition_map.items():
        table[from_idx * num_classes + to_idx] = semantic_idx
    transitions = torch.tensor(table, dtype=torch.long, device=device)

    # Fully vectorized: look up every pixel's transition, then zero out unchanged
    # pixels (avoids the host sync that nonzero() forces).
    semantic_classes = transitions[lcm_2019 * num_classes + lcm_2024]  # (B, H, W)
    return torch.where(change_mask, semantic_classes, torch.zeros_like(semantic_classes))

### Define the checkpoint path, model name, etc.

#### Strategy 1: PCC - Linknet

In [None]:
# # Initialize model and device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# architecture = 'linknet'  # 'unet' or 'linknet', 'pspnet', 'deeplabv3plus'
# num_classes = 13   # Change Detection classes (3 for cd1, 13 for cd2)
# num_semantic_classes = 4   # Semantic segmentation LCM classes (4 for both)
# num_epochs = 100
# loss = 'CE'
# checkpoint_path = f'{SAVING_DIR}/best_{architecture}_{num_epochs}_epochs.pt'

# # Create model
# model = ChangeDetectionModel(
#     architecture=architecture,encoder='resnet34',
#     input_channels=3,num_classes=num_classes,
#     num_semantic_classes=num_semantic_classes
# ).to(device)

#### Strategy 2: Siam Diff

In [None]:
# # Initialize and train the model
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model_name = 'siamunet_diff'
# strategy = 'st2' #change detection strategy {1,2,3,4}
# num_classes = 13  #num classes in change mask
# num_epochs = 100
# weighting_method = 'square_balanced' #'custom'
# loss = 'CE' #'focal' #'bcl'
# #checkpoint_path = f'{SAVING_DIR}/best_{strategy}_{model_name}_{num_epochs}.pt'
# checkpoint_path = f'{SAVING_DIR}/best_{strategy}_{model_name}-{num_classes}_classes_{num_epochs}.pt'

# model = SiamUnet_diff(input_nbr=3, label_nbr=num_classes).to(device)

#### Strategy 3

In [None]:
# Strategy 3 configuration (binary CD network + shared land-cover network)
num_classes = 13  # Change Detection classes (3 for cd1, 13 for cd2)
num_semantic_classes = 4  # Semantic Segmentation LCM classes (4 for both)
num_epochs = 100
weighting_method = 'square_balanced'
loss = 'CE'
# alternative checkpoint location: 'models/strategy3_model.pt'
checkpoint_path = f'{SAVING_DIR}/best_Strat3_{num_epochs}_epochs.pt'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Create model
strategy3_config = {
    'cd_architecture': 'unet',
    'lcm_architecture': 'unet',
    'cd_encoder': 'resnet34',
    'lcm_encoder': 'resnet34',
    'input_channels': 3,
    'num_classes': num_classes,
    'num_semantic_classes': num_semantic_classes,
}
model = Strategy3Model(**strategy3_config).to(device)

#### Strategy 4

In [None]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # Model configuration
# input_channels = 3
# num_semantic_classes = 4 #len(SEMANTIC_CLASSES)
# num_change_classes = 13 #len(CLASSES)
# num_classes = num_change_classes
# num_epochs = 100

# # Create model
# model = MultiTaskChangeDetectionModel(
#     input_channels=input_channels,
#     num_semantic_classes=num_semantic_classes,
#     num_change_classes=num_change_classes
# ).to(device)

# # Define checkpoint paths
# lcm_checkpoint_path = f'{SAVING_DIR}/best_lcm_model_{num_epochs}.pt'
# full_checkpoint_path = f'{SAVING_DIR}/best_full_model_{num_epochs}.pt'

### Model Testing

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os

def find_sample_idx_by_filename(test_loader, filename):
    """
    Find the dataset index of the tile whose 2019 filename matches *filename*.

    An exact basename match is preferred. If there is none, the first basename
    that merely contains *filename* as a substring is returned, so partial names
    keep working.

    Args:
        test_loader: DataLoader whose ``dataset`` exposes ``t2019_paths``
            (as ChangeDetectionDatasetTIF does).
        filename: Filename, or part of one, to search for (without path).

    Returns:
        int: Index of the sample, or -1 if not found.
    """
    basenames = [os.path.basename(path) for path in test_loader.dataset.t2019_paths]

    # Exact match first: a plain substring search can return an earlier tile whose
    # name only contains the query (e.g. 'SouthAustin_1.tif' for 'Austin_1.tif').
    for idx, name in enumerate(basenames):
        if name == filename:
            return idx
    for idx, name in enumerate(basenames):
        if filename in name:
            return idx
    return -1

def _show_panel(ax, image, title, save_file=None, **imshow_kwargs):
    """Draw *image* on *ax* titled *title*; also write it to *save_file* when given.

    *imshow_kwargs* (e.g. cmap/vmin/vmax) go to both ``imshow`` and ``imsave`` so the
    saved PNG uses the same colours as the on-screen panel.
    """
    ax.imshow(image, **imshow_kwargs)
    ax.set_title(title)
    ax.axis('off')
    if save_file is not None:
        plt.imsave(save_file, image, **imshow_kwargs)


def visualize_single_sample(model, test_loader, sample_idx, device='cpu',
                            num_classes=3, save_path=None):
    """
    Visualize (and optionally save) the prediction for one sample of test_loader.

    The sample is read directly from ``test_loader.dataset[sample_idx]``. Earlier
    batches are therefore not loaded, and the index always refers to the dataset,
    independent of the loader's batch size or shuffling.

    Four panels are drawn: 2019 image, 2024 image, predicted change mask and ground
    truth. When *save_path* is given, they are also written to
    ``f"{save_path}_img2019.png"``, ``..._img2024.png``, ``..._CDprediction.png`` and
    ``..._CDtruth.png``.

    Args:
        model: Strategy 3 model (exposes ``cd_model`` and ``lcm_model``) already on *device*.
        test_loader: DataLoader containing the test data.
        sample_idx: Dataset index of the sample to visualize.
        device: Device the model lives on; inputs are moved there.
        num_classes: Number of change-detection classes (3 for cd1, 13 for cd2);
            selects the colour palette and the colormap range.
        save_path: Optional filename prefix for the saved PNGs; nothing is saved when None.
    """
    ########### STRATEGY 1,2 ################
    # checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # model.load_state_dict(checkpoint['model_state_dict'])
    # print(f"Loaded checkpoint from {checkpoint_path}")

    ########### STRATEGY 3 ################
    # checkpoint = torch.load(checkpoint_path, map_location=device)
    # print(f"Loading checkpoint from {checkpoint_path}")
    # model.cd_model.load_state_dict(checkpoint['cd_model'])
    # model.lcm_model.load_state_dict(checkpoint['lcm_model'])
    # print("Successfully loaded both CD and LCM models")

    ########### STRATEGY 4 ################
    # checkpoint_path=full_checkpoint_path
    # checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # model.load_state_dict(checkpoint['model_state_dict'])
    # print(f"Loaded checkpoint from {checkpoint_path}")

    # NOTE(review): every checkpoint-loading variant above is commented out, so the
    # model is used with whatever weights it already holds. Uncomment the block for
    # the active strategy unless the weights are loaded elsewhere.

    model.eval()

    dataset = test_loader.dataset
    num_samples = len(dataset)
    # Validate against the real dataset size: the last batch may be partial, so
    # `num_batches * batch_size - 1` would overstate the maximum valid index.
    if not 0 <= sample_idx < num_samples:
        print(f"Sample index {sample_idx} is out of range. Maximum index is {num_samples - 1}")
        return

    # Index the dataset directly instead of iterating the loader up to the
    # target batch, which reads every preceding batch from disk.
    img1, img2, true_mask = dataset[sample_idx]

    # Get prediction
    with torch.no_grad():
        # Add batch dimension and move to the model's device
        img1_batch = img1.unsqueeze(0).to(device)
        img2_batch = img2.unsqueeze(0).to(device)

        ########### STRATEGY 1,2 ################
        #output = model(img1_batch, img2_batch)
        #pred_mask = torch.argmax(output, dim=1)[0].cpu().numpy()

        ########### STRATEGY 4 ################
        #_,_,output = model(img1_batch, img2_batch)
        #pred_mask = torch.argmax(output, dim=1)[0].cpu().numpy()

        ########### STRATEGY 3 ################
        # Use the device-resident batches; CPU tensors fail against a GPU model.
        cd_pred = model.cd_model(torch.cat([img1_batch, img2_batch], dim=1))
        lcm_pred_2019 = model.lcm_model(img1_batch)
        lcm_pred_2024 = model.lcm_model(img2_batch)
        semantic_pred = create_semantic_change_mask(cd_pred, lcm_pred_2019, lcm_pred_2024)
        # Back to host memory as a numpy array for matplotlib
        pred_mask = semantic_pred[0].cpu().numpy()

    # Display arrays (channels-last images, numpy mask)
    img1_display = img1.cpu().numpy().transpose(1, 2, 0)
    img2_display = img2.cpu().numpy().transpose(1, 2, 0)
    true_mask_display = true_mask.cpu().numpy()

    # Palette must match the class count, otherwise vmin/vmax spread the
    # listed colours over the wrong class ids.
    if num_classes == 3:
        colors = ['black', 'green', 'red']  # 3 classes (cd1)
    else:
        colors = ['black', 'lightgray', 'gray', 'darkgray',  # 13 classes (cd2)
                  'darkblue', 'green', 'darkgreen',
                  'lightblue', 'orange', 'lightgreen',
                  'blue', 'orangered', 'peachpuff']
    cmap = plt.matplotlib.colors.ListedColormap(colors)
    mask_kwargs = dict(cmap=cmap, vmin=0, vmax=num_classes - 1)

    def _target(suffix):
        # Output file for one panel, or None to skip saving.
        return None if save_path is None else f"{save_path}_{suffix}.png"

    # Create visualization
    _, axes = plt.subplots(1, 4, figsize=(20, 5))
    plt.subplots_adjust(wspace=0.3)

    _show_panel(axes[0], img1_display, 'Image 1', _target('img2019'))
    _show_panel(axes[1], img2_display, 'Image 2', _target('img2024'))
    _show_panel(axes[2], pred_mask, 'Predicted Change', _target('CDprediction'), **mask_kwargs)
    _show_panel(axes[3], true_mask_display, 'Ground Truth', _target('CDtruth'), **mask_kwargs)

    plt.show()

# Example usage:

# Find sample by filename and visualize
filename = 'NorthCarolina_Charlotte_W_2019_2.tif'
editedfilename = filename.replace('.tif', '')
display_loader = test_loader  # Using image from test_loader for visualization
# Label used in the output folder/file names. Keep it in sync with the strategy
# cell that built `model`; visualize_single_sample runs the Strategy 3 branch.
model_name = 'strategy3'
rooting_dir = f"{SAVING_DIR}/{model_name}_{num_classes}-classes_{editedfilename}"

# makedirs also creates SAVING_DIR when missing (os.mkdir would raise) and is a
# no-op when the folder already exists.
os.makedirs(rooting_dir, exist_ok=True)

sample_idx = find_sample_idx_by_filename(display_loader, filename)

if sample_idx >= 0:
    visualize_single_sample(model, display_loader, sample_idx,
                            device=device, num_classes=num_classes,
                            save_path=f"{rooting_dir}/{model_name}_{num_classes}-classes_{editedfilename}_")
else:
    print(f"File {filename} not found in dataset.")

## Finding good city tiles

### Random Image loader
- Used for finding good example images (for the paper)
- The sample index `j` is written in the 2019 image title; use it below to look up the city's filename

In [None]:
import matplotlib.pyplot as plt
import random

# Number of random test samples to show (one row of 2019 / 2024 / mask each).
# The same value sizes the grid and the loop, so no empty rows are drawn.
num_rows = 2

# Set up the plot; squeeze=False keeps `axs` 2-D even when num_rows == 1.
fig, axs = plt.subplots(num_rows, 3, figsize=(8, 8), squeeze=False)

for i in range(num_rows):
    j = random.randint(0, len(test_dataset) - 1)
    # `j` is the sample_index (can be used further)
    image1, image2, mask = test_dataset[j]
    # Display images
    axs[i, 0].imshow(image1.permute(1, 2, 0))
    axs[i, 0].set_title(f"Real 2019 {j}")
    axs[i, 0].axis("off")

    axs[i, 1].imshow(image2.permute(1, 2, 0))
    axs[i, 1].set_title("Real 2024")
    axs[i, 1].axis("off")

    axs[i, 2].imshow(mask, cmap="turbo")
    axs[i, 2].set_title("CD Mask")
    axs[i, 2].axis("off")

plt.show()

### Retrieve filename from `index`
- Use the sample index from above to get the filename from the dataset

In [None]:
import os

loader = 'test'
image_dir = f"ChangeDetectionMergedDividedSplit-tif3/{loader}/Images/T2019"

# Get the sorted list of .tif filenames. Filter exactly like ChangeDetectionDatasetTIF
# does; otherwise any non-.tif file in the folder (e.g. a GDAL .aux.xml sidecar)
# shifts the indices relative to the dataset.
image_filenames = sorted(f for f in os.listdir(image_dir) if f.endswith('.tif'))

# Define the index
idx = 490  # Change to your required index

# Retrieve filename using index
filename = image_filenames[idx]
print(f"Filename at index {idx}: {filename}")

## Generating images and masks for the paper

### Convert entire TIF folder to PNG

In [None]:
import os
import rasterio
import numpy as np
from PIL import Image

def _normalize_to_uint8(array):
    """Min-max scale *array* to 0-255 and return it as uint8.

    A constant array (max == min) maps to all zeros instead of dividing by zero.
    """
    shifted = array - array.min()
    peak = shifted.max()
    if peak == 0:
        return np.zeros(shifted.shape, dtype=np.uint8)
    return (shifted / peak * 255).astype(np.uint8)


def tif_to_png(tif_path, png_path):
    """Save a 1- or 3-band .tif raster as a min-max-stretched 8-bit PNG.

    Args:
        tif_path: Source .tif file (read with rasterio).
        png_path: Destination .png file (written with Pillow).

    Class masks are stretched as well, so PNG grey levels are not class ids.
    Rasters with other band counts are passed to Pillow in (bands, H, W) layout,
    which presumably fails there.
    """
    # Read the .tif file using rasterio
    with rasterio.open(tif_path) as src:
        array = src.read()
        if array.shape[0] == 3:  # RGB image
            array = np.moveaxis(array, 0, -1)  # Reorder dimensions to (H, W, C)
        elif array.shape[0] == 1:  # Grayscale image
            array = array[0]  # Remove the single-band dimension

    # Normalize the array to range [0, 255] for saving as PNG
    array = _normalize_to_uint8(array)

    # Save the NumPy array as a .png image using Pillow
    img = Image.fromarray(array)
    img.save(png_path)
    print(f"Saved {png_path}")

def convert_all_tifs_in_folder(input_folder, output_folder):
    """Convert every ``.tif`` directly inside *input_folder* to a PNG in *output_folder*.

    Each ``name.tif`` becomes ``name.png`` via ``tif_to_png``. *output_folder* is
    created if needed and may be the same as *input_folder*. Files are processed in
    sorted order; a message is printed when there is nothing to convert.
    """
    # Ensure the output folder exists
    os.makedirs(output_folder, exist_ok=True)

    # List all .tif files in the input folder (sorted for a deterministic order)
    tif_files = sorted(f for f in os.listdir(input_folder) if f.endswith('.tif'))

    if not tif_files:
        print(f"No .tif files found in {input_folder}")
        return

    # Convert each .tif file to .png and save in the output folder
    for tif_file in tif_files:
        tif_path = os.path.join(input_folder, tif_file)
        # Swap only the extension; str.replace would also rewrite any '.tif'
        # occurring earlier in the name.
        stem = os.path.splitext(tif_file)[0]
        png_path = os.path.join(output_folder, stem + ".png")
        tif_to_png(tif_path, png_path)

# Example usage: convert the folder in place (the PNGs are written next to the TIFs).
directory1 = "./Organized/SouthDakota_SiouxFalls_E"
directory2 = directory1
convert_all_tifs_in_folder(directory1, directory2)

### Display and save the image
- Directly from `TIF` to `PNG`

In [None]:
import os
import matplotlib.pyplot as plt
import rasterio

# Load the .tif image. os.path.join replaces the old backslash literal, whose
# "\S" and "\c" were invalid escape sequences and which only worked on Windows.
tif_path = os.path.join("Organized", "SouthDakota_SiouxFalls_E", "cd2_m_SouthDakota_SiouxFalls_E.tif")

with rasterio.open(tif_path) as src:
    image = src.read(1)  # Read the first band (for grayscale images)

#colors = ['lightblue', 'white', 'lightgreen','darkgreen']  # Seg mask (all 4 classes present)
#colors = ['white', 'lightgreen','darkgreen']  # Seg mask (3 classes; water absent, so ids start at 1 - adjust vmin)
#colors = ['black', 'green','red']  # cd1 - MCD mask (3 classes)
colors = ['black', 'lightgray', 'gray', 'darkgray',   # cd2 - SCD mask (13 classes)
          'darkblue', 'green', 'darkgreen',
          'lightblue', 'orange', 'lightgreen',
          'blue', 'orangered', 'peachpuff']
cmap = plt.matplotlib.colors.ListedColormap(colors)

# Pin the colour range to the palette so class k always gets colors[k], even when
# the tile does not contain every class (assumes class ids 0..len(colors)-1).
color_range = dict(vmin=0, vmax=len(colors) - 1)

# Plot the image with a user-defined colormap
plt.figure(figsize=(4, 4))
plt.imshow(image, cmap=cmap, **color_range)

######## Adding a legend ########
#import matplotlib.patches as mpatches
# Define the class labels
# class_labels = ['0: no change', '1: water to building', '2: water to sparse', '3: water to dense',
#                 '4: building to water', '5: building to sparse', '6: building to dense',
#                 '7: sparse to water', '8: sparse to building', '9: sparse to dense',
#                 '10: dense to water', '11: dense to building', '12: dense to sparse']

# Create a list of patches to be shown in the legend
#patches = [mpatches.Patch(color=colors[i], label=class_labels[i]) for i in range(len(class_labels))]

# Add the legend to the plot
#plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), ncol=4,loc='upper left', borderaxespad=0., prop={'family': 'Times New Roman'})
#fig = plt.gcf()
#fig.subplots_adjust(right=0.8)
#fig.set_size_inches(1, 10)

plt.axis("off")
plt.imsave(tif_path.replace('.tif', '_colormap.png'), image, cmap=cmap, **color_range)

plt.show()

## Plotting State vs Num Cities graph

### List all filenames in given folder
- Useful for counting cities

In [None]:
import os

def get_filenames(folder_path):
    """List the regular files (sub-directories excluded) directly inside *folder_path*.

    Names are returned without the folder prefix, in ``os.listdir`` order.
    """
    names = []
    for entry in os.listdir(folder_path):
        if not os.path.isfile(os.path.join(folder_path, entry)):
            continue
        names.append(entry)
    return names

# Example usage
folder_path = "./new_eur"  # Change this to your folder path
filenames = get_filenames(folder_path)


def _city_key(fname):
    """Reduce a mask filename to its '<State>_<City>' key.

    Removes 'cd1_m_' and '.tif', drops the last two characters (presumably a
    tile suffix such as '_1'), then drops one trailing underscore if the cut left
    one behind (a two-character suffix such as '10'). An empty remainder raises
    IndexError.
    """
    stem = fname.replace("cd1_m_", "").replace(".tif", "")[:-2]
    if stem[-1] == "_":
        stem = stem[:-1]
    return stem


filenames2 = list(filenames)
filenames3 = [_city_key(fname) for fname in filenames2]
# Unique city keys in alphabetical order
filenames4 = sorted(set(filenames3))

print(len(filenames4))

### Replace some filenames

In [None]:
replacements = {
    'Alabama_Dayton': 'Ohio_Dayton',
    'Alabama_Cincinnati': 'Ohio_Cincinnati',
    'Alabama_Toledo': 'Ohio_Toledo',
    'Illinois_Greenbay': 'Wisconsin_Greenbay',
    'Illinois_FortWayne': 'Indiana_FortWayne',
    'Illinois_SouthBend': 'Indiana_SouthBend'
}

# Apply replacements, then de-duplicate and sort: a corrected name (e.g.
# 'Ohio_Dayton') may already be present, and a city must not be counted twice.
filenames4_replaced = sorted(set(replacements.get(name, name) for name in filenames4))

print(len(filenames4_replaced))

# Split only on the first underscore so every entry is a [state, city] pair,
# even when the city part itself contains underscores.
state_city_split = [name.split('_', 1) for name in filenames4_replaced]
print(state_city_split)


### Plot bar graph for state vs num of cities

In [None]:
from collections import Counter

import matplotlib.pyplot as plt

# Count the number of cities per state. The state is the first piece of each
# split name; indexing instead of unpacking `state, city` also tolerates entries
# that were split into more than two pieces.
state_counts = Counter(parts[0] for parts in state_city_split)

# Extract states and their corresponding counts
states = list(state_counts.keys())
city_counts = list(state_counts.values())

print(len(states))

# Plot the bar chart
plt.figure(figsize=(15, 7))
plt.bar(states, city_counts, color='#9198db')
plt.xlabel('Countries')
plt.ylabel('Number of Cities')
plt.title('Number of Cities per Country in Europe')
plt.xticks(rotation=90)
plt.show()
