In [132]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import fsspec
from xarray.backends import NetCDF4DataStore
import xarray as xr
import matplotlib.pyplot as plt
import os
import numpy as np

In [133]:

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x


class InConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(InConv, self).__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class Up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(Up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)

        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class OutConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

class Unet(nn.Module):
    def __init__(self, in_channels, classes):
        super(Unet, self).__init__()
        self.n_channels = in_channels
        self.n_classes =  classes

        self.inc = InConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x

In [134]:
class Fire_Dataset(torch.utils.data.Dataset):
    def __init__(self, root, mode="train", transform=None):

        assert mode in {"train", "valid", "test"}

        self.root = root
        self.mode = mode
        self.transform = transform

        self.images_directory = os.path.join(self.root, "images")
        self.masks_directory = os.path.join(self.root, "annotations")
 
        self.filenames = self._get_files(suffix = '.npy')  # read train/valid/test splits

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):

        filename = self.filenames[idx]
        
        image_path = os.path.join(self.images_directory, filename)
        mask_path = os.path.join(self.masks_directory, filename)
        
        image = torch.from_numpy(np.load(image_path))
        mask = torch.from_numpy(np.load(mask_path))
        
        sample = dict(image = image, mask = mask)
        
        

        return sample


    def _get_files(self, suffix):
        files = []
        
        if self.mode == 'train':
            dir_to_walk = os.path.join(self.root, 'images')
        elif self.mode == 'valid':
            dir_to_walk = os.path.join(self.root, 'annotations')
            
            
        for r,d,f in os.walk(dir_to_walk):
            for file in f:
                if file.endswith(suffix):
                    files.append(os.path.join(r,file))
        
        
        if self.mode == "train":  # 90% for train
            filenames = [x for i, x in enumerate(files) if i % 10 != 0]
        elif self.mode == "valid":  # 10% for validation
            filenames = [x for i, x in enumerate(files) if i % 10 == 0]

        filenames = [f.split('/')[-1] for f in filenames]
        return filenames
    
    def return_files(self):
        return self.filenames
    
    def _normalize(tensor):
        mean, std, var = torch.mean(tensor), torch.std(tensor), torch.var(tensor)

        return (tensor-mean)/std

bads = [0,11,16,17,18] somewhere these have to drop?

In [135]:
def preprocess(dataset):
    def pad_along_axis(array: np.ndarray, target_length: int, axis: int = 0) -> np.ndarray:

        print(array.shape)
        pad_size = target_length - array.shape[axis]

        if pad_size <= 0:
            return array

        npad = [(0, 0)] * array.ndim
        npad[axis] = (0, pad_size)

        return np.pad(array, pad_width=npad, mode='edge')
    
    bands = list(dataset.data_vars)
    print(len(bands))

    dataset = dataset.drop_vars('goes_imager_projection')
    dataset = dataset.drop_vars('spatial_ref')
    output = {}
    bands2 = list(dataset.data_vars)
    
    for b in bands2:
        if b != 'fire_bool':
            output[b] = pad_along_axis(dataset[b].values, 64, axis = 1)[:,:,:64].squeeze()
        else:
            output[b] = pad_along_axis(dataset[b].values, 64, axis = 0)[:,:64].squeeze()
    
    return output


In [136]:
def dice_metric(inputs, target):
    intersection = 2.0 * (target * inputs).sum()
    union = target.sum() + inputs.sum()
    if target.sum() == 0 and inputs.sum() == 0:
        return 1.0

    return intersection / union

def dice_loss(inputs, target):
    num = target.size(0)
    inputs = inputs.reshape(num, -1)
    target = target.reshape(num, -1)
    smooth = 1.0
    intersection = (inputs * target)
    dice = (2. * intersection.sum(1) + smooth) / (inputs.sum(1) + target.sum(1) + smooth)
    dice = 1 - dice.sum() / num
    return dice

def bce_dice_loss(inputs, target):
    dicescore = dice_loss(inputs, target)
    bcescore = nn.BCELoss()
    bceloss = bcescore(inputs, target)

    return bceloss + dicescore

In [137]:
def drop_bands(batch):
    t = batch['image'].float()
    
    return torch.cat(((t[:,1:11,:,:], t[:,12:16,:,:],t[:,19:,:,:])), axis=1)

In [138]:
def normalize(np_arr):
    normalizes = []
    for band in np_arr:
       
        mini, maxi = np.min(band), np.max(band)
        
        normalized = (band - mini) / (maxi - mini)
        normalizes.append(normalized)
            
    
    
    return normalizes

In [139]:

# Load the model
unet = Unet(21, 1)
unet.load_state_dict(torch.load('/Users/adamhunter/Documents/school projs/firenet/data/model2.pt'))

# Create a file system
fs = fsspec.filesystem('gs', project='firenet-99')

# List all .nc files in the bucket
files = fs.glob('preprocessed_firenet_input/*.nc')

# Initialize a dictionary to store the results
results = {}


In [140]:

# Process each file
for file in files[:100]:
    # Open the file
    with fs.open('gs://' + file) as f:
        ds = xr.open_dataset(f)

        # Extract spatial information
        spatial_info = ds.attrs.get('spatial_info')

        # Preprocess the data
        data = preprocess(ds)

        feature_names = list(data.keys())
        print(len(feature_names))
        bool_remove = [f.startswith('DQF_') for f in feature_names]
        other_remove = ['goes_imager_projection', 'spatial_ref']
        
        feature_names = [d for (d, remove) in zip(feature_names, bool_remove) if not remove]
        feature_names = [k for k in feature_names if k not in other_remove]
        indices_to_drop = [0, 11, 16, 17, 18]
        feature_names = [feature for idx, feature in enumerate(feature_names) if idx not in indices_to_drop]
        print(feature_names)
        features = np.array([data[f] for f in feature_names])
        features = np.stack(normalize(features))

        # Convert the data to PyTorch tensors
        input_data = torch.from_numpy(features)

        input_data = input_data.unsqueeze(0)  # Add an extra dimension for batch size

        # Run the model on the unseen data
        predictions = unet(input_data)
        # Store the predictions and spatial information in the results dictionary
        results[file] = {'predictions': predictions.detach().numpy(), 'spatial_info': spatial_info}


44
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
42
['CMI_C06', 'CMI_C07', 'CMI_C08', 'CMI_C09', 'CMI_C10', 'CMI_C11', 'CMI_C12', 'CMI_C13', 'CMI_C14', 'CMI_C15', 'CMI_C16', 'CMI_C03', 'CMI_C04', 'CMI_C05', 'feat_14_7', 'LC22_CBD_220', 'LC22_EVH_220', 'LC20_Elev_220', 'LC22_EVC_220', 'LC22_FVH_220', 'LC22_F40_220']
44
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64

  normalized = (band - mini) / (maxi - mini)


44
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
42
['CMI_C06', 'CMI_C07', 'CMI_C08', 'CMI_C09', 'CMI_C10', 'CMI_C11', 'CMI_C12', 'CMI_C13', 'CMI_C14', 'CMI_C15', 'CMI_C16', 'CMI_C03', 'CMI_C04', 'CMI_C05', 'feat_14_7', 'LC22_CBD_220', 'LC22_EVH_220', 'LC20_Elev_220', 'LC22_EVC_220', 'LC22_FVH_220', 'LC22_F40_220']
44
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64, 64)
(1, 64