# N-Channeled-Input-UNet-Fastai

In [19]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%autosave 10

from fastai.vision import *
from fastai.utils.mem import *
from fastai.vision.learner import cnn_config # notice this extra import 
from torch.utils.data import Dataset
import torch.nn as nn

Autosaving every 10 seconds


In [2]:
torch.cuda.set_device(1) # comment this out 

## Data

Don't worry about this, you can use a pytorch dataset similar to this for reading the input 
images/ arrays, I would recommend using a pytorch dataloader for this purpose in this 
case because it might take time for you to figure out to get fastai datasets to work with 
n channeled datasets. You wont loose much functionality as anyway you can't plot your 
array/images using show_batch() method, If you want to make sure your input shape is correct
use `data.train_ds[0][0].shape` to check the shape.

In [3]:
class NChanneledDataset(nn.Module):
    def __init__(self, n_input_channels=5, n_output_channels=2):
        super().__init__()
        self.n_input_channels = n_input_channels
        self.c = n_output_channels
        
    def __len__(self):
        return 100 # just some dummy value
    
    def __getitem__(self, idx):
        #read your nchannled array, image
        x = torch.randn(self.n_input_channels, 224, 224)
        y = torch.randint(low=0, high=self.c, size=(224, 224))
        return x, y 

train_ds = NChanneledDataset()
valid_ds = NChanneledDataset()

data = DataBunch.create(train_ds=train_ds, valid_ds=valid_ds, bs=2)

data

## unet_learner function

This function has been copied from the source given in the docs [link](https://github.com/fastai/fastai/blob/c667c17d4b684ff174795748d747d9e7180d8b1e/fastai/vision/learner.py#L109). I have made a few changes in type casting for input paramters (won't make a difference). Also, I have changed how it gets the size dimensions, just to be safe. 
The major change is changing the first layer of the model before it goes into the Dynamic Unet creator. 

In [16]:
def unet_learner(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
                 norm_type:Optional[NormType]=NormType, split_on=None, blur:bool=False,
                 self_attention:bool=False, y_range=None, last_cross:bool=True,
                 bottle:bool=False, cut=None, **learn_kwargs:Any)->Learner:
    "Build Unet learner from `data` and `arch`."
    
    # I have defined size intentionally like this, so that it won't be a problem when 
    # the input is an image 
    size = next(iter(data.train_dl))[0].shape[-2:]
    n_input_channels = next(iter(data.train_dl))[0][0].size(0)    
        
    meta = cnn_config(arch)
    body = create_body(arch, pretrained, cut)
    
    # changing the first layer to suit our input
    if not n_input_channels == 3: 
        prev_layer = body[0]
        body[0] = nn.Conv2d(n_input_channels, prev_layer.out_channels, 
                      kernel_size=prev_layer.kernel_size, 
                      stride=prev_layer.stride, 
                      padding=prev_layer.padding, 
                      bias=prev_layer.bias)

    model = to_device(models.unet.DynamicUnet(body, n_classes=data.c, img_size=size, blur=blur, blur_final=blur_final,
          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
          bottle=bottle), data.device)
    learn = Learner(data, model, **learn_kwargs)
    learn.split(ifnone(split_on, meta['split']))
    if pretrained and n_input_channels == 3: learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn

In [17]:
learn = unet_learner(data, models.resnet34) # instantiate a learner for testing

## testing

I have used fit one cycle function just to make sure it is trianing (although on dummy data). I have also used the `model` attribute of the learner to get the model out and test it on a dummy pytorch tensor. 

In [18]:
learn.fit_one_cycle(1) # dont care about the loss as the data is dummy. 

epoch,train_loss,valid_loss,time
0,-1.4397454606381264e+24,-3.0427967760266653e+24,00:09


In [10]:
model = learn.model

In [15]:
model(torch.randn(2, 5, 224, 224).cuda()).shape

torch.Size([2, 2, 224, 224])