In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

class MultispectralImage:
    """A multi-band raster image stored as a (height, width, num_bands) array.

    Bands are addressed by name: ``band_names[i]`` labels ``data[:, :, i]``.
    """

    def __init__(self, data, band_names):
        """Wrap an image array and its band labels.

        Args:
            data: array-like of shape (height, width, num_bands). An ndarray
                is stored without copying, so in-place edits made through
                ``set_pixel`` are visible in the caller's array.
            band_names: 1-D sequence of band labels, one per band. Lists and
                tuples are accepted and converted to a numpy array.

        Raises:
            ValueError: if ``data`` is not 3-D or the number of band names
                does not match the number of bands.
        """
        self.data = np.asarray(data)
        # Name lookups rely on elementwise ``==``; comparing a plain list to
        # a string yields a single False, so normalize to an ndarray here.
        self.band_names = np.asarray(band_names)
        if self.data.ndim != 3:
            raise ValueError(
                f"data must be a 3D array (height, width, num_bands), got {self.data.ndim}D"
            )
        if self.band_names.ndim != 1 or len(self.band_names) != self.data.shape[2]:
            raise ValueError(
                f"Expected {self.data.shape[2]} band names, got {self.band_names.size}"
            )
        self.shape = self.data.shape[:2]  # (height, width); bands never change it

    def _band_index(self, band_name):
        """Return the positions of every band named ``band_name``.

        Raises:
            ValueError: if no band has that name.
        """
        band_idx = np.where(self.band_names == band_name)[0]
        if len(band_idx) == 0:
            raise ValueError(f"Band {band_name} not found")
        return band_idx

    def get_band(self, band_name):
        """Return the 2-D (height, width) slice for the first band named ``band_name``.

        Raises:
            ValueError: if no band has that name.
        """
        return self.data[:, :, self._band_index(band_name)[0]]

    def get_pixel(self, x, y):
        """Return all band values at column ``x``, row ``y``."""
        return self.data[y, x, :]

    def set_pixel(self, x, y, pixel_data):
        """Overwrite all band values at column ``x``, row ``y`` in place."""
        self.data[y, x, :] = pixel_data

    def add_band(self, new_band_data, band_name):
        """Append a (height, width) array as a new band named ``band_name``.

        Raises:
            ValueError: if ``new_band_data`` does not match the image's
                height and width.
        """
        new_band_data = np.asarray(new_band_data)
        if new_band_data.shape != (self.shape[0], self.shape[1]):
            raise ValueError("New band data must have the same dimensions as existing image data")
        self.data = np.dstack((self.data, new_band_data))
        self.band_names = np.append(self.band_names, band_name)

    def add_band_from_1d_array(self, new_band_data, band_name):
        """Append a flattened band (height * width values, row-major) as ``band_name``.

        Raises:
            ValueError: if the number of values differs from height * width.
        """
        new_band_data = np.asarray(new_band_data)
        # Check the total element count so a mismatched input fails here
        # with a clear message instead of inside ``reshape``.
        if new_band_data.size != self.shape[0] * self.shape[1]:
            raise ValueError("New band data must have the same number of pixels as existing image data")
        self.add_band(new_band_data.reshape((self.shape[0], self.shape[1])), band_name)

    def remove_band(self, band_name):
        """Delete every band named ``band_name``.

        Raises:
            ValueError: if no band has that name.
        """
        band_idx = self._band_index(band_name)
        self.data = np.delete(self.data, band_idx, axis=2)
        self.band_names = np.delete(self.band_names, band_idx)

    def to_dataframe(self):
        """Return a DataFrame with one row per pixel (row-major) and one column per band.

        If band names repeat, the later band's column replaces the earlier one.
        """
        return pd.DataFrame(
            {name: self.data[:, :, i].ravel() for i, name in enumerate(self.band_names)}
        )

    def plot_band(self, band_name):
        """Display the band named ``band_name`` as a grayscale image.

        Raises:
            ValueError: if no band has that name.
        """
        plt.imshow(self.get_band(band_name), cmap='gray')
        plt.title(band_name)
        plt.show()

    def plot_all_bands(self):
        """Display every band side by side in a single row of grayscale images.

        Raises:
            ValueError: if the image has no bands.
        """
        num_bands = len(self.band_names)
        if num_bands == 0:
            raise ValueError("Image has no bands to plot")
        # squeeze=False keeps ``axs`` 2-D even for a single band; by default a
        # one-column figure returns a bare Axes that cannot be indexed.
        fig, axs = plt.subplots(nrows=1, ncols=num_bands, figsize=(15, 5), squeeze=False)
        for i, (ax, band_name) in enumerate(zip(axs[0], self.band_names)):
            # Index positionally so duplicate names still show each band.
            ax.imshow(self.data[:, :, i], cmap='gray')
            ax.set_title(band_name)
        plt.show()



In [None]:
# Build a small 3-band example image. MultispectralImage requires both the
# pixel array and one label per band; calling it with no arguments raises
# TypeError.
demo_data = np.arange(4 * 5 * 3, dtype=float).reshape(4, 5, 3)
image = MultispectralImage(demo_data, np.array(["red", "green", "blue"]))
image.get_band("red")