In [None]:
# Put these at the top of every notebook, to get automatic reloading and inline plotting
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from fastai import *
from fastai.vision import *
import PIL

import pandas as pd
import numpy as np
import cv2
from pathlib import *
import colorcet as cc

# Linear grey colormap (`linear_grey_0_100_c0`) from colorcet, kept under a short name.
# NOTE(review): presumably used for displaying the grayscale images; it is not referenced in the cells shown here.
cmap_grey = cc.cm.linear_grey_0_100_c0

In [1]:
#functions to convert the nn architecture to grayscale input
from fastai.core import *

def getGrayStats( imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ):
    """Collapse per-channel (mean, std) normalisation stats into single-channel stats.

    Args:
        imagenet_stats: a pair ``(means, stds)`` with one value per colour channel.
            Defaults to the usual ImageNet RGB statistics.

    Returns:
        A two-element list ``[mean, std]`` of float tensors, where ``mean`` is the
        average of the channel means and ``std`` is the root-mean-square of the
        channel stds. An empty list is returned when the input is not at least
        2-D or has a single channel column, i.e. there is nothing to collapse.
    """
    arr = np.asarray(imagenet_stats)

    # Guard: only a 2-D (or higher) array with more than one channel is converted.
    if arr.ndim < 2 or arr.shape[1] <= 1:
        return []

    channel_means, channel_stds = arr[0], arr[1]
    n_channels = arr.shape[1]

    gray_mean = np.mean(channel_means)
    # Root-mean-square of the per-channel standard deviations.
    gray_std = np.sqrt( sum(channel_stds*channel_stds) / n_channels )

    return [ torch.from_numpy( np.asarray(value) ).float() for value in (gray_mean, gray_std) ]

def set_trainable(l, b):
    """Set the trainable flag `b` on `l` by handing `set_trainable_attr` to `apply_leaf`."""
    def _mark(leaf):
        set_trainable_attr(leaf, b)

    apply_leaf(l, _mark)

def set_trainable_attr(m,b):
    """Record `b` as `m.trainable` and set `requires_grad` to `b` on every parameter of `m`."""
    m.trainable = b
    for param in m.parameters():
        param.requires_grad = b
        
def rgbModule2gray(module, rgb2gray=[0.299, 0.587, 0.114], sort_by_strength=False ):
    """Build a grayscale-input version of a network whose first child is an RGB conv layer.

    The first child of `module` is replaced by a frozen 1-input-channel `nn.Conv2d`
    whose filters are a weighted sum of the original R, G and B filters. The
    remaining children are reused unchanged (not copied) in a new `nn.Sequential`.

    Args:
        module: network whose first child is an `nn.Conv2d` with 3 input channels.
        rgb2gray: weights applied to the R, G and B filters. The default is the
            luma conversion Y = 0.299 R + 0.587 G + 0.114 B; use ``[1., 1., 1.]``
            for a plain sum of the three filters.
        sort_by_strength: if True, reorder the output filters by decreasing sum of
            absolute weights (useful for visualising the filters). This permutes
            the output channels, so they no longer line up with the later layers
            that were trained against the original order; leave it False when the
            returned network is going to be used as a model.

    Returns:
        An `nn.Sequential` that takes single-channel images as input.

    Raises:
        ValueError: if the first child is not an `nn.Conv2d` with
            ``in_channels == 3`` and ``groups == 1``.
    """
    children = list(module.children())
    l_rgb = children[0] if children else None
    if not isinstance(l_rgb, nn.Conv2d) or l_rgb.in_channels != 3 or l_rgb.groups != 1:
        raise ValueError("rgbModule2gray expects the first child of `module` to be an "
                         "nn.Conv2d with in_channels=3 and groups=1")

    # Scale the coefficients so they sum to 3. With equal coefficients the gray filter
    # is then exactly the sum of the three channel filters, i.e. the response of the
    # RGB layer to an image with R == G == B.
    mix = np.asarray(rgb2gray, dtype=np.float64)
    mix = 3.0 * mix / mix.sum()

    # nn.Conv2d takes `bias` as a bool; the values themselves are copied further down.
    has_bias = l_rgb.bias is not None
    conv_2d_gray = nn.Conv2d(1, out_channels=l_rgb.out_channels, kernel_size=l_rgb.kernel_size,
                             stride=l_rgb.stride, padding=l_rgb.padding, dilation=l_rgb.dilation,
                             bias=has_bias)

    # Weighted sum over the input-channel axis, done on the source layer's device/dtype.
    rgb_weight = l_rgb.weight.data
    mix_t = torch.from_numpy(mix).to(device=rgb_weight.device, dtype=rgb_weight.dtype).view(1, 3, 1, 1)
    gray_weight = (rgb_weight * mix_t).sum(dim=1, keepdim=True)
    gray_bias = l_rgb.bias.data.clone() if has_bias else None

    if sort_by_strength:
        # Strongest filters first (visualisation only, see the docstring warning).
        strength = gray_weight.abs().view(gray_weight.shape[0], -1).sum(dim=1)
        order = strength.sort(descending=True)[1]
        gray_weight = gray_weight[order]
        if has_bias:
            gray_bias = gray_bias[order]

    # Install the converted parameters and freeze the gray layer.
    conv_2d_gray.weight = nn.Parameter(gray_weight, requires_grad=False)
    if has_bias:
        conv_2d_gray.bias = nn.Parameter(gray_bias, requires_grad=False)
    conv_2d_gray.trainable = False

    # New gray layer first, followed by every original child except the RGB layer.
    return nn.Sequential(conv_2d_gray, *children[1:])


#Small class to intercept the instantiation of the model in cnn_create and convert the input filter to grayscale
class Model2Grayscale:
    """Callable wrapper around an architecture factory that yields a grayscale-input model.

    An instance is called with the `pretrained` flag in place of the wrapped
    factory; it builds the model and converts its input layer with `rgbModule2gray`.
    """
    def __init__( self, arch ):
        # `arch` is a callable invoked as `arch(pretrained)`.
        self.arch = arch

    def __call__(self, pretrained):
        return rgbModule2gray(self.arch(pretrained))

In [None]:
#Class for reading 16-bit grayscale images
class GrayImageDataset(ImageClassificationDataset):
    """Image classification dataset that loads its items as single-channel float images.

    Files are opened with PIL, converted to mode 'I' (32-bit integer pixels) and
    divided by 65536.0.
    NOTE(review): the 65536 divisor presumably assumes 16-bit source images — confirm.
    """
    @staticmethod
    def create(path, dfData ):
        """Build a dataset from a dataframe with `fnImage` (file name) and `classes` (label) columns.

        NOTE(review): relies on a module-level `dir_im` sub-directory name that is not
        defined in the cells shown here — it must exist before this is called.
        """
        return GrayImageDataset( fns = [path/dir_im/f  for f in dfData.fnImage.values],
                                 labels = dfData.classes.values )
    @staticmethod
    def pil2tensor(image)->TensorImage:
        "Convert PIL style `image` array to torch style image tensor."
        arr = torch.from_numpy(np.asarray(image))
        # PIL reports size as (width, height); reshape to (height, width, channels).
        arr = arr.view(image.size[1], image.size[0], -1)
        # Channels first: (channels, height, width).
        return arr.permute(2,0,1)
    @staticmethod
    def open_image(fn:PathOrStr)->Image:
        "Open `fn` as an integer-mode image and return it as a float `Image` scaled by 1/65536."
        x = PIL.Image.open(fn).convert('I')
        return Image(GrayImageDataset.pil2tensor(x).float().div_(65536.0))

    # Concrete implementation, so it must not be decorated with @abstractmethod.
    def _get_x(self,i):
        "Load the `i`-th image with the grayscale reader."
        return GrayImageDataset.open_image(self.x[i])
    
# Build the training and validation datasets from the rows of `tvData` tagged with each purpose.
# NOTE(review): `path` and `tvData` are not defined in the cells shown here — they must already exist in the kernel.
# NOTE(review): the validation set is built from the rows marked "test" — confirm this is intended.
train_ds = GrayImageDataset.create( path, tvData[tvData.purpose=="train"] )
valid_ds = GrayImageDataset.create( path, tvData[tvData.purpose=="test"] )