Notebook for STM double tip simulation

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as F
import os
import random
from scipy.ndimage import gaussian_filter, median_filter

In [None]:
def create_double_gaussian_kernel(size, peak1, peak2, sigma1, sigma2):
    '''
    Build a square kernel containing two Gaussian blobs added together.

    Each blob is made by placing a single 1 in an all-zero (size, size)
    array and smoothing it with scipy's gaussian_filter.

    Args:
        size: side length of the square kernel.
        peak1: index (row, col) of the first blob's centre.
        peak2: index (row, col) of the second blob's centre.
        sigma1: standard deviation passed to gaussian_filter for the first blob.
        sigma2: standard deviation passed to gaussian_filter for the second blob.

    Returns:
        The (size, size) sum of the two blurred impulses. The result is not
        normalised here.
    '''
    def _blurred_impulse(position, sigma):
        # A unit impulse at `position`, smeared out into a Gaussian blob.
        impulse = np.zeros((size, size))
        impulse[position] = 1
        return gaussian_filter(impulse, sigma=sigma)

    first_blob = _blurred_impulse(peak1, sigma1)
    second_blob = _blurred_impulse(peak2, sigma2)
    return first_blob + second_blob

# Settings for the example double-tip kernel: a 16x16 grid holding two
# Gaussian blobs of equal width, centred at (4, 4) and (11, 11).
kernel_size = 16
peak1 = (4, 4)
peak2 = (11, 11)
sigma1 = 1.0
sigma2 = 1.0

# Build the kernel, then rescale it so that all of its entries sum to one.
kernel = create_double_gaussian_kernel(kernel_size, peak1, peak2, sigma1, sigma2)
kernel = kernel / kernel.sum()

print("Kernel shape:", kernel.shape)
print(kernel)

from scipy.ndimage import convolve

# Demonstrate the kernel on a random (512, 512) array.
array = np.random.rand(512, 512)
convolved_array = convolve(array, weights=kernel)

print("Convolved array shape:", convolved_array.shape)


In [None]:
# custom transforms 

class Double_tip1(object):
    '''
    Add a double tip artefact to a 2D image.

    A "ghost" copy of the image is made by passing every pixel through a
    sigmoid, shifting the result by a random offset (with wrap-around) and
    smoothing it with a median filter of random size. The ghost is then added
    to the original image. The offsets and the filter size are drawn from
    numpy's global random state on every call, so each image gets a different
    artefact; the sigmoid parameters a and b are supplied by the caller.
    '''

    def __init__(self):
        pass

    @staticmethod
    def sigmoid(x, a, b):
        '''
        Evaluate 1 / (1 + exp(a - b*x)) element-wise.

        This is used to rescale the pixel values in the "doubled" image (the
        brighter ones are often the ones that are doubled more clearly).

        Args:
            x: scalar or array of pixel values.
            a: offset inside the exponential.
            b: slope inside the exponential.
        '''
        return 1 / (1 + np.exp(a - b * x))

    def __call__(self, array, a, b):
        '''
        Return `array` with a double tip artefact added.

        Args:
            array: 2D numpy array holding the image.
            a, b: sigmoid parameters, see `sigmoid`.

        Returns:
            A new array of the same shape: array + filtered, shifted ghost.

        Raises:
            ValueError: if `array` is not two-dimensional.
        '''
        if np.ndim(array) != 2:
            raise ValueError('Double_tip1 expects a 2D array.')

        # Rescale the pixel values that will make up the ghost image.
        sigmoid_array = self.sigmoid(array, a, b)

        # Shift magnitude is an integer in [2, 10] along each axis; the sign
        # of each shift is then chosen with equal probability.
        offset_x = np.random.randint(2, 11)
        offset_y = np.random.randint(2, 11)
        if np.random.rand() < 0.5:
            offset_x = -offset_x
        if np.random.rand() < 0.5:
            offset_y = -offset_y
        print('offset in x and y: ', offset_x, offset_y)

        # offset_x moves the ghost along axis 0 and offset_y along axis 1.
        # Pixels pushed past an edge wrap around to the opposite side.
        offset_array = np.roll(sigmoid_array, shift=(offset_x, offset_y), axis=(0, 1))

        # Median filter window size is an integer in [1, 10].
        median_size = np.random.randint(1, 11)
        print('median filter size: ', median_size)
        filtered_offset = median_filter(offset_array, size=median_size)

        # Add the smoothed ghost (not the raw shifted copy) to the original.
        return array + filtered_offset


class RandomScanLineArtefact(object):
    '''
    Randomly adds scan line artefacts to the image.

    The input is first min-max normalised to [0, 1]; the artefacts are added on
    top of the normalised image, which is what gets returned.

    Does a few different types of scan line artefacts:
    - Adds a constant to a single line in the image but for a longer length
    - Adds a constant to two lines in the image but for a longer length
    - Adds a sinusoidal wave to a single line in the image
    - Adds a sinusoidal wave to two lines in the image
    - Adds a constant to a single line in the image but for a shorter length

    Args:
        p: probability of adding the group of long artefacts, and
           (independently) the probability of adding the group of short ones.
        seed: optional seed for the numpy generator that places the artefacts.
              The generator is created once per transform, so successive calls
              produce different artefacts. Leave as None for fresh entropy.
    '''
    def __init__(self, p, seed=None):
        self.p = p
        self.rng = np.random.default_rng(seed)

    def _draw_artefacts(self, count, n_rows, n_cols, max_fraction):
        # Draw row, start column, length and amplitude for `count` artefacts.
        # The upper bound on the length is kept >= 1 so tiny images don't
        # produce an empty sampling range.
        max_length = max(1, int(n_cols * max_fraction))
        lines = self.rng.integers(0, n_rows, (count,))
        columns = self.rng.integers(0, n_cols, (count,))
        lengths = self.rng.integers(0, max_length, (count,))
        add_ons = self.rng.random(size=(count,)) / 1.67  # between 0 and ~0.6
        return lines.tolist(), columns.tolist(), lengths.tolist(), add_ons.tolist()

    def _cosine_profile(self, width, dtype):
        # cos(t) sampled at `width` points for t in [0, end], end in [2.00, 3.13].
        end = self.rng.integers(200, 314) / 100
        wave = np.cos(np.linspace(0, end, num=width))
        return torch.from_numpy(wave).to(dtype)

    def __call__(self, scan):
        '''
        Args:
            scan: tensor indexed as (channel, line, column).

        Returns:
            A new tensor of the same shape: the min-max normalised scan with
            any artefacts added. `scan` itself is not modified.
        '''
        r1 = random.random()
        r2 = random.random()

        lowest = torch.min(scan)
        value_range = torch.max(scan) - lowest
        # A constant image has zero range; divide by one instead so the
        # result is all zeros rather than NaN.
        value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
        scan_maxmin = (scan - lowest) / value_range

        n_rows, n_cols = scan.shape[1], scan.shape[2]

        if r1 < self.p:
            # 15 long artefacts (up to 80% of a line): indices 0-6 constant on
            # one line, 7-8 constant on two lines, 9-12 cosine on one line,
            # 13-14 cosine on two lines.
            drawn = zip(*self._draw_artefacts(15, n_rows, n_cols, 0.8))
            for i, (row, start, length, add_on) in enumerate(drawn):
                n_lines = 2 if i in (7, 8, 13, 14) else 1
                # Slicing clips at the image edge, so the region may be
                # shorter than `length` (or empty).
                region = scan_maxmin[:, row:row + n_lines, start:start + length]
                if i < 9:
                    region += add_on
                else:
                    # The wave is sized from the clipped region's column axis
                    # so it always matches the pixels it is added to.
                    region += self._cosine_profile(region.shape[-1], region.dtype) * add_on

        if r2 < self.p:
            # 10 short constant artefacts (up to 10% of a line).
            drawn = zip(*self._draw_artefacts(10, n_rows, n_cols, 0.1))
            for row, start, length, add_on in drawn:
                scan_maxmin[:, row, start:start + length] += add_on

        return scan_maxmin

# Custom transformation for random rotation
class RandomRotation_(object):
    '''
    Randomly rotates the image by an angle chosen from a list of angles.

    Args:
        angles: non-empty sequence of candidate rotation angles; one is picked
                uniformly at random (via random.choice) on every call and
                passed to torchvision's functional rotate.
    '''
    def __init__(self, angles):
        self.angles = angles

    def __call__(self, img):
        # NOTE: earlier versions called img.transpose(0, 1) before and after
        # the rotation but discarded the results, so those calls had no effect
        # on the output. They have been removed; the rotation is unchanged.
        angle = random.choice(self.angles)
        return F.rotate(img, angle)

class RandomCreep(object):
    '''
    Randomly adds a creep to the top or bottom of the image.

    Creep is simulated by shifting pairs of scan lines sideways (with
    wrap-around) by an amount that grows quadratically towards the top edge
    of the image. Half of the time the result is flipped vertically so that
    the distortion sits at the bottom instead.

    Args:
        p: probability that a creep is added on a given call.
    '''
    def __init__(self, p):
        self.p = p

    def __call__(self, scan):
        '''
        Args:
            scan: tensor whose last two dimensions are (line, column),
                  e.g. (H, W) or (C, H, W).

        Returns:
            A new tensor of the same shape. When no creep is drawn this is an
            unmodified copy of `scan`.
        '''
        r = random.random()
        r2 = random.random()
        scan_ = torch.clone(scan)
        if r >= self.p:
            # No creep this time: hand back the untouched copy.
            return scan_

        num_lines = np.random.randint(40, 50)  # num line pairs the creep is visible in
        a1 = np.random.randint(0, 50)
        a2 = np.random.randint(0, 50)  # factors in the creep polynomial

        for i in range(num_lines):
            # i = 0 is the pair furthest from the top (no shift); the shift
            # grows as the pairs get closer to the top edge.
            j = num_lines - i
            roll = a1 * i * i // 250 + a2 * i * i // 100 + i * i // 50
            strip = scan[..., 2 * j:2 * (j + 1), :]
            if strip.numel() == 0:
                # Image has fewer lines than the creep region.
                continue
            # Shift along the column axis, i.e. along each scan line.
            scan_[..., 2 * j:2 * (j + 1), :] = torch.roll(strip, -int(roll), dims=-1)

        if r2 < 0.5:
            # flip the image so the creep is on the bottom
            scan_ = torch.flip(scan_, [-2])
        return scan_
    
    
class MedianFilter(object):
    '''
    Transform that smooths an image with scipy's median filter.

    Args:
        size: window size handed to scipy.ndimage.median_filter.
    '''
    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # Go through numpy for the filtering, then wrap the result in a tensor.
        as_numpy = np.array(img)
        smoothed = median_filter(as_numpy, size=self.size)
        return torch.tensor(smoothed)
    

In [None]:
class STM_double_tip_dataset(Dataset):
    '''
    Dataset of STM images stored as numpy arrays, one array per file.

    Each item is a pair (image, label) where label is a copy of image. No
    transforms are applied by this class.
    '''
    def __init__(self, image_dir, empty = False):
        '''
        Args:
            image_dir (string): Directory with all the images. These are assumed as being numpy arrays.
                                Either with shape (res,res,2) (filled and empty), or just (res,res) (filled only).
            empty (bool): If True, the images are assumed as having filled and empty state images and both are wanted.
                          If False, we take only the filled state images.

        Raises:
            ValueError: if image_dir contains no files, if the files mix 2D
                        and 3D arrays (or have another number of dimensions),
                        or if empty=True is requested for 2D images.
        '''
        self.empty = empty
        self.image_dir = image_dir

        # Only regular files count as images; sorting gives a stable index
        # order regardless of how the OS lists the directory.
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, f))
        )
        if not self.image_files:
            raise ValueError('image_dir does not contain any files.')

        # check all files in image_dir have the same number of dimensions
        ndims = [np.load(os.path.join(image_dir, f)).ndim for f in self.image_files]
        all_3d = all(n == 3 for n in ndims)
        all_2d = all(n == 2 for n in ndims)
        if not (all_3d or all_2d):
            raise ValueError('All files in image_dir must have the same shape. Either (res,res,2) or (res,res).')
        if empty and all_2d:
            raise ValueError('empty=True only makes sense if the images have two channels.')

        # Planned transform order (not applied by this class): random blur,
        # random rotation, random brightness, creep, double tip, crop, resize.

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        '''
        Load item `idx` (wrapped modulo the dataset length) from disk.

        Returns:
            (image, label): image is the array as a tensor, reduced to
            channel 0 (the filled-state image, per the constructor's
            convention) when the file is 3D and empty=False; label is a
            clone of image.
        '''
        idx = idx % self.__len__()
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        array = np.load(img_name)
        if array.ndim == 3 and not self.empty:
            array = array[:, :, 0]
        image = torch.from_numpy(np.ascontiguousarray(array))
        label = torch.clone(image)

        return image, label

In [4]:
20%100

20

In [5]:
200%20

0