<a href="https://colab.research.google.com/github/matsunagalab/ColabBTR/blob/main/ColabBTR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## ColabBTR: Blind tip reconstruction on Google Colab

An easy-to-use notebook for running end-to-end differentiable blind tip reconstruction and removing tip-convolution artifacts from your AFM data.

*   Read and write AFM data in several file formats
*   Perform end-to-end differentiable blind tip reconstruction to estimate the tip shape from the AFM data alone
*   Perform erosion (deconvolution) with the reconstructed tip to remove tip-convolution artifacts


[Y. Matsunaga, S. Fuchigami, T. Ogane, S. Takada. End-to-end differentiable blind tip reconstruction for noisy atomic force microscopy images.
*Scientific Reports*, 2023](https://www.nature.com/articles/s41598-022-27057-2)

In [1]:
#@title Install dependencies
%%time
import os

print("installing colabbtr...")
os.system("pip install -q --no-warn-conflicts git+https://github.com/matsunagalab/ColabBTR")

print("installing libasd...")
os.system("pip install -q --no-warn-conflicts libasd")

print("upgrading gdown...")
os.system("pip install -q --upgrade --no-warn-conflicts gdown")

print("installing spmpy...")
os.system("pip install -q --no-warn-conflicts git+https://github.com/kogens/spmpy")

print("installing gwyfile...")
os.system("pip install -q --no-warn-conflicts gwyfile")

print("✨All installation completed!✨")


installing colabbtr...
installing libasd...
upgrading gdown...
installing spmpy...
installing gwyfile...
✨All installation completed!✨
CPU times: user 78.9 ms, sys: 13 ms, total: 91.9 ms
Wall time: 17.5 s


In [2]:
#@title Upload AFM data (asd, gwy, spm, npy, npz, csv files are supported)
#@markdown Click the button that appears below this cell to upload your AFM file.
from google.colab import files
import os

def create_directory(path: str) -> None:
    """Ensure the directory ``path`` exists, creating missing parents.

    An already existing directory is left untouched.
    """
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)

def upload_and_move_file(upload_path: str) -> str:
    """Upload exactly one file through the Colab widget and move it into ``upload_path``.

    Parameters
    ----------
    upload_path : str
        Existing directory that receives the uploaded file.

    Returns
    -------
    str
        Path of the moved file, ``os.path.join(upload_path, filename)``.

    Raises
    ------
    ValueError
        If zero or more than one file was uploaded. Any uploaded files are
        deleted first so they do not linger next to the notebook.
    """
    import shutil

    uploaded = files.upload()
    if len(uploaded) != 1:
        # Clean up the rejected uploads before reporting the error.
        for name in uploaded:
            if os.path.isfile(name):
                os.remove(name)
        raise ValueError("Please upload exactly one file.")

    filename = next(iter(uploaded))
    destination = os.path.join(upload_path, filename)
    # shutil.move also works when source and destination are on different filesystems.
    shutil.move(filename, destination)
    return destination

#@markdown - **Jobname**: used for name of job or directory
jobname = 'test' #@param {type:"string"}
# Placeholder pixel resolution; in this cell it is only printed and written to metadata.txt.
resolution_x = 1
resolution_y = 1

try:
    if not jobname.strip():
        raise ValueError("Jobname must not be empty.")

    # Create job directory
    custom_path = os.path.join(jobname, "afm_data")
    create_directory(custom_path)

    # Upload and move file
    file_path = upload_and_move_file(custom_path)

    print(f"Job name: {jobname}")
    print(f"Upload path: {custom_path}")
    print(f"Uploaded file path: {file_path}")
    print(f"Resolution: {resolution_x}x{resolution_y}")

    # Save metadata
    metadata_path = os.path.join(custom_path, "metadata.txt")
    with open(metadata_path, "w") as f:
        f.write(f"Job name: {jobname}\n")
        f.write(f"Resolution: {resolution_x}x{resolution_y}\n")
        f.write(f"Uploaded file path: {file_path}\n")

except Exception as e:
    print(f"An error occurred: {str(e)}")
    # Re-raise: swallowing the error would let "Run all" continue into later
    # cells, which then fail with a confusing NameError on `file_path`.
    raise

Saving images.npy to images.npy
Job name: test
Upload path: test/afm_data
Uploaded file path: test/afm_data/images.npy
Resolution: 1x1


In [3]:
#@title Load the uploaded data
import torch
import libasd
import numpy as np
import matplotlib.pyplot as plt
import gwyfile
import os

#@markdown - **Channel number** (option for 2ch *.asd* file)<br>
#@markdown example: 0
channel_number_in_asd = 0#@param {type:"raw"}
#@markdown - **Name of channels** (option for *.gwy* file)<br>
#@markdown example: topography
channel_name_in_gwyfile = "topography"#@param {type:"string"}

# Record the loader options next to the uploaded data so the run can be reproduced.
custom_path = os.path.join('/content', jobname, 'afm_data')
_param_lines = [
    f'channel_number_in_asd: {channel_number_in_asd}\n',
    f'channel_name_in_gwyfile: {channel_name_in_gwyfile}\n',
]
with open(os.path.join(custom_path, 'upload_file_params.txt'), 'w') as f:
    f.writelines(_param_lines)

def load_asd_file(file_path, channel_number):
    """Load every frame of an ``.asd`` file into a NumPy array.

    Parameters
    ----------
    file_path : str
        Path to the ``.asd`` file.
    channel_number : int
        Channel to read when the file has two channels (0 or 1); ignored for
        single-channel files.

    Returns
    -------
    numpy.ndarray
        ``frame.image()`` of every frame, stacked along axis 0.

    Raises
    ------
    ValueError
        If the file has a number of channels other than 1 or 2, or if
        ``channel_number`` is not 0 or 1 for a two-channel file.
    """
    data = libasd.read_asd(file_path)
    print(f"ASD file version: {data.header.file_version}")

    n_channels = len(data.channels)
    if n_channels == 1:
        frames = data.frames
    elif n_channels == 2:
        # Validate up front instead of surfacing a bare IndexError from libasd.
        if channel_number not in (0, 1):
            raise ValueError(
                f"channel_number must be 0 or 1 for a 2-channel ASD file, got {channel_number!r}")
        frames = data.channels[channel_number]
    else:
        raise ValueError("Unsupported number of channels in ASD file")

    return np.array([frame.image() for frame in frames])

def load_gwy_file(file_path, channel_name):
    """Load one named data channel from a Gwyddion ``.gwy`` file.

    Parameters
    ----------
    file_path : str
        Path to the ``.gwy`` file.
    channel_name : str
        Key of the channel in ``gwyfile.util.get_datafields``.

    Returns
    -------
    numpy.ndarray
        The channel's ``data`` with a leading frame axis of length 1 added.

    Raises
    ------
    KeyError
        If ``channel_name`` is not in the file; the message lists the
        available channel names.
    """
    obj = gwyfile.load(file_path)
    channels = gwyfile.util.get_datafields(obj)
    if channel_name not in channels:
        raise KeyError(
            f"Channel {channel_name!r} not found in {file_path}; "
            f"available channels: {list(channels)}")
    return np.expand_dims(channels[channel_name].data, axis=0)

def load_numpy_file(file_path):
    """Load AFM frames from a ``.npy`` or ``.npz`` file.

    ``np.load`` returns an ndarray for ``.npy`` files, but a dict-like
    ``NpzFile`` archive for ``.npz`` files, which cannot be turned into a
    tensor directly. For an archive, the array stored under ``"images"`` is
    used if present, otherwise the first stored array.

    A 2-D array (a single frame) is promoted to shape ``(1, H, W)`` so callers
    always get a leading frame axis, like ``load_gwy_file`` returns.

    Parameters
    ----------
    file_path : str
        Path to the ``.npy`` / ``.npz`` file.

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    ValueError
        If a ``.npz`` archive contains no arrays.
    """
    loaded = np.load(file_path)
    if isinstance(loaded, np.ndarray):
        array = loaded
    else:
        # .npz archive: read the wanted member, then close the file handle.
        with loaded as archive:
            if not archive.files:
                raise ValueError(f"No arrays found in {file_path}")
            key = "images" if "images" in archive.files else archive.files[0]
            array = archive[key]

    if array.ndim == 2:
        array = array[np.newaxis, ...]
    return array

# Dictionary mapping file extensions to their respective loading functions
loaders = {
    '.asd': load_asd_file,
    '.gwy': load_gwy_file,
    '.npy': load_numpy_file,
    '.npz': load_numpy_file
}

# Load the data
file_extension = os.path.splitext(file_path)[-1].lower()

if file_extension not in loaders:
    raise ValueError(f"Unsupported file type: {file_extension}")

# The .asd and .gwy loaders take a format-specific option from the form fields above.
loader_func = loaders[file_extension]
if file_extension == '.asd':
    images = loader_func(file_path, channel_number_in_asd)
elif file_extension == '.gwy':
    images = loader_func(file_path, channel_name_in_gwyfile)
else:
    images = loader_func(file_path)

# Everything below indexes images as (frame, height, width): promote a single
# 2-D frame to a one-frame stack and reject anything else with a clear message.
images = np.asarray(images)
if images.ndim == 2:
    images = images[np.newaxis, ...]
if images.ndim != 3:
    raise ValueError(
        f"Expected AFM data with shape (frames, height, width), got shape {images.shape}")

# Create tensors
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
images = torch.tensor(images, dtype=torch.float32, device=device)
nframe = images.shape[0]

# Display information
print(f"Loaded file: {file_path}")
print(f"Image shape: {images.shape}")
print(f"Number of frames: {nframe}")
print(f"Pixel height: {images.shape[1]}")
print(f"Pixel width: {images.shape[2]}")

# Optional: Save the loaded data
np.save(os.path.join(custom_path, 'loaded_images.npy'), images.cpu().numpy())

Loaded file: test/afm_data/images.npy
Image shape: torch.Size([576, 30, 30])
Number of frames: 576
Pixel height: 30
Pixel width: 30


In [8]:
#@title Visualize the data
from google.colab import output
output.enable_custom_widget_manager()

import plotly.graph_objects as go
from IPython.display import display, HTML
import ipywidgets as widgets


#@markdown - **Output frame range**
first_frame = 0#@param {type:"raw"}
nframe = images.shape[0]
last_frame = nframe#@param {type:"raw"}
#@markdown - **Colormap range**
vmin = images.min().item()#@param {type:"raw"}
vmax = images.max().item()#@param {type:"raw"}
#@markdown - **Output image size**
figsize = 600#@param {type:"raw"}

# Keep the requested range inside the data so the slider cannot index past the last frame.
last_frame = min(last_frame, nframe)
if not 0 <= first_frame < last_frame:
    raise ValueError(
        f"Invalid frame range: first_frame={first_frame}, last_frame={last_frame} "
        f"(data has {nframe} frames)")

# Save parameter file
custom_path = os.path.join('/content', jobname, 'afm_data')
with open(os.path.join(custom_path, 'visualize_input_params.txt'), 'w') as f:
    f.write(f'first_frame: {first_frame}\n')
    f.write(f'last_frame: {last_frame}\n')
    f.write(f'vmin: {vmin}\n')
    f.write(f'vmax: {vmax}\n')
    f.write(f'figsize: {figsize}\n')

# Create the initial plot
fig = go.FigureWidget()

heatmap = go.Heatmap(
    z=images[first_frame].cpu().numpy(),
    zmin=vmin,
    zmax=vmax,
    colorscale='Hot',
    colorbar=dict(title='Height')
)

fig.add_trace(heatmap)

fig.update_layout(
    title=f'AFM Image Visualization - Frame {first_frame}',
    width=figsize,
    height=figsize,
    xaxis_title='X',
    yaxis_title='Y',
)

# Create a slider widget
slider = widgets.IntSlider(
    value=first_frame,
    min=first_frame,
    max=last_frame-1,
    step=1,
    description='Frame:',
    continuous_update=False,
    layout=widgets.Layout(width='80%')
)

# Bind the figure and image stack as default arguments: later cells rebind the
# globals `fig` and `images`, and this slider must keep driving *this* figure
# with *these* images.
def update_plot(change, fig=fig, images=images):
    """Redraw the heatmap for the frame selected on the slider."""
    frame = change['new']
    with fig.batch_update():
        fig.data[0].z = images[frame].cpu().numpy()
        fig.layout.title.text = f'AFM Image Visualization - Frame {frame}'

# Connect the slider to the update function
slider.observe(update_plot, names='value')

# Display the slider and the plot
display(widgets.VBox([slider, fig]))

VBox(children=(IntSlider(value=0, continuous_update=False, description='Frame:', layout=Layout(width='80%'), m…

In [9]:
#@title (Optional) Correct stage tilt by RANSAC

import numpy as np
import torch
import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.linear_model import RANSACRegressor
from google.colab import output

output.enable_custom_widget_manager()

# Output directory for this step (corrected stack, inlier mask, parameters).
custom_path = os.path.join('/content', jobname, 'stage_tilt')
os.makedirs(custom_path, exist_ok=True)

def correct_afm_tilt_multi_frame(images, random_state=None):
    """Remove a common planar stage tilt from a stack of AFM frames using RANSAC.

    One plane ``z = a*x + b*y + c`` is fitted robustly to the pixels of all
    frames pooled together and then subtracted from every frame.

    Parameters
    ----------
    images : torch.Tensor
        Height data indexed as ``images[frame, y, x]``.
    random_state : int, numpy.random.RandomState or None, optional
        Forwarded to ``RANSACRegressor`` to make the random subsampling
        reproducible. The default ``None`` keeps it unseeded.

    Returns
    -------
    corrected_images_tensor : torch.Tensor
        ``images`` minus the fitted plane; same shape, dtype and device.
    inlier_mask_tensor : torch.Tensor
        Boolean tensor with the same shape and device; True where the pixel
        was a RANSAC inlier.
    """
    np_images = images.cpu().numpy()

    nframes, height, width = np_images.shape
    y_grid, x_grid = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')

    # `np_images.reshape(-1)` is frame-major (all pixels of frame 0, then frame 1, ...),
    # so the per-pixel (x, y) coordinates must be *tiled* once per frame to stay
    # aligned with the heights. (np.repeat would repeat each pixel nframes times
    # in a row and pair heights with the wrong coordinates.)
    features = np.column_stack([np.tile(x_grid.ravel(), nframes),
                                np.tile(y_grid.ravel(), nframes)])
    heights = np_images.reshape(-1)

    # Perform plane fitting using RANSAC for all frames combined
    ransac = RANSACRegressor(random_state=random_state)
    ransac.fit(features, heights)

    # Get parameters of the plane (ax + by + c = z)
    a, b = ransac.estimator_.coef_
    c = ransac.estimator_.intercept_

    # The (H, W) plane broadcasts over the frame axis.
    corrected_images = np_images - (a * x_grid + b * y_grid + c)

    # Same frame-major order as `heights`, so this reshape restores (frame, y, x).
    inlier_mask = ransac.inlier_mask_.reshape(nframes, height, width)

    # Return tensors with the same dtype and device as the input.
    corrected_images_tensor = torch.tensor(corrected_images, dtype=images.dtype, device=images.device)
    inlier_mask_tensor = torch.tensor(inlier_mask, dtype=torch.bool, device=images.device)

    return corrected_images_tensor, inlier_mask_tensor

# Re-running this cell must not stack a second correction on top of the first:
# if `images` is still the tensor this cell produced last time, start again
# from the tensor that earlier run started from.
_tilt_source = globals().get('images_before_tilt_correction')
if _tilt_source is not None and images is globals().get('corrected_images'):
    images = _tilt_source
images_before_tilt_correction = images

corrected_images, inlier_mask = correct_afm_tilt_multi_frame(images)
np.save(f'{custom_path}/corrected_images', corrected_images.cpu().numpy())
np.save(f'{custom_path}/inlier_mask', inlier_mask.cpu().numpy())

# Later cells (BTR, erosion) read the tilt-corrected stack from `images`.
images = corrected_images
#@markdown - **Frame range**: for visualization
first_frame = 0#@param {type:"raw"}
last_frame = nframe#@param {type:"raw"}
#@markdown - **Colormap range**
vmin = images.min().item()#@param {type:"raw"}
vmax = images.max().item()#@param {type:"raw"}
#@markdown - **Output image size**
figsize = 600#@param {type:"raw"}

# Keep the requested range inside the data so the slider cannot index past the last frame.
last_frame = min(last_frame, images.shape[0])
if not 0 <= first_frame < last_frame:
    raise ValueError(
        f"Invalid frame range: first_frame={first_frame}, last_frame={last_frame} "
        f"(data has {images.shape[0]} frames)")

# Save parameter file
with open(f'{custom_path}/stage_tilt_params.txt', 'w') as f:
    f.write(f'first_frame: {first_frame}\n')
    f.write(f'last_frame: {last_frame}\n')
    f.write(f'vmin: {vmin}\n')
    f.write(f'vmax: {vmax}\n')
    f.write(f'figsize: {figsize}\n')

# Create the initial plot (own name, so the global `fig` of other cells is untouched)
fig_ransac = make_subplots(rows=1, cols=2, subplot_titles=('Corrected Image', 'Inlier Mask'))

heatmap_corrected = go.Heatmap(
    z=images[first_frame].cpu().numpy(),
    zmin=vmin,
    zmax=vmax,
    colorscale='Hot',
    colorbar=dict(title='Height', x=0.45)
)

heatmap_inlier = go.Heatmap(
    z=inlier_mask[first_frame].cpu().numpy(),
    zmin=0,
    zmax=1,
    colorscale=[[0, 'rgb(255,0,0)'], [1, 'rgb(0,255,0)']],  # Red for outliers, Green for inliers
    colorbar=dict(title='Inlier', x=1.0,
                  tickvals=[0, 1],
                  ticktext=['Outlier', 'Inlier'])
)

fig_ransac.add_trace(heatmap_corrected, row=1, col=1)
fig_ransac.add_trace(heatmap_inlier, row=1, col=2)

fig_ransac.update_layout(
    title=f'RANSAC Tilt Correction - Frame {first_frame}',
    width=figsize * 2,
    height=figsize,
)

# Give each subplot a 1:1 aspect ratio; subplot 2 uses axes x2/y2, so it must
# anchor to "y2" rather than the first subplot's "y".
fig_ransac.update_xaxes(scaleanchor="y", scaleratio=1, row=1, col=1)
fig_ransac.update_xaxes(scaleanchor="y2", scaleratio=1, row=1, col=2)

# Create a slider widget
slider_ransac = widgets.IntSlider(
    value=first_frame,
    min=first_frame,
    max=last_frame-1,
    step=1,
    description='Frame:',
    continuous_update=False,
    layout=widgets.Layout(width='80%')
)

# Wrap the Plotly figure in a FigureWidget
fig_widget = go.FigureWidget(fig_ransac)

# Bind the widget and data as default arguments so later rebinding of these
# globals in other cells cannot change what this slider shows.
def update_plot(change, fig_widget=fig_widget, images=images, inlier_mask=inlier_mask):
    """Redraw both heatmaps for the frame selected on the slider."""
    frame = change['new']
    with fig_widget.batch_update():
        fig_widget.data[0].z = images[frame].cpu().numpy()
        fig_widget.data[1].z = inlier_mask[frame].cpu().numpy()
        fig_widget.layout.title.text = f'RANSAC Tilt Correction - Frame {frame}'

# Connect the slider to the update function
slider_ransac.observe(update_plot, names='value')

# Display the slider and the plot
display(widgets.VBox([slider_ransac, fig_widget]))

VBox(children=(IntSlider(value=0, continuous_update=False, description='Frame:', layout=Layout(width='80%'), m…

In [None]:
#@title Working in progress... (Optional) Determine optimal weight decay for AdamW by cross-validation


In [11]:
#@title Run the end-to-end differentiable BTR and reconstruct tip shape from AFM images

%%time
from colabbtr.morphology import differentiable_btr
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default = "colab"

custom_path = os.path.join('/content', jobname, f'differentiable_BTR')
os.makedirs(custom_path, exist_ok=True)

#@markdown - **Frame range**: used for estimation of tip shape
first_frame = 0#@param {type:"raw"}
last_frame = 30#@param {type:"raw"}
#@markdown - **Tip size**: used for tip height and width
tip_size = 10#@param {type:"raw"}
tip_height = tip_size
tip_width = tip_size
#@markdown - **Settings for learning**
epoch = 200#@param {type:"raw"}
learning_rate = 0.1#@param {type:"raw"}
weight_decay = 0.0#@param {type:"raw"}

# Save parameter file
with open(f'{custom_path}/differentiable_BTR_params.txt', 'w') as f:
    f.write(f'first_frame: {first_frame}\n')
    f.write(f'last_frame: {last_frame}\n')
    f.write(f'tip_size: {tip_size}\n')
    f.write(f'epoch: {epoch}\n')
    f.write(f'learning_rate: {learning_rate}\n')
    f.write(f'weight_decay: {weight_decay}\n')

tip, loss = differentiable_btr(images[first_frame:last_frame, :, :],
                             (tip_height, tip_width),
                             nepoch=epoch, lr=learning_rate, weight_decay=weight_decay)
np.save(f'{custom_path}/tip', tip.to('cpu').numpy())
np.save(f'{custom_path}/loss', loss)

# Plot interactive loss function
fig_loss = go.Figure()
fig_loss.add_trace(go.Scatter(y=loss, mode='lines', name='Loss'))
fig_loss.update_layout(title='Loss function',
                       xaxis_title='Epoch',
                       yaxis_title='Loss',
                       width=800,
                       height=500)
fig_loss.show()

# Plot interactive 3D tip shape
fig_3d = go.Figure(data=[go.Surface(z=tip.to('cpu').numpy())])
fig_3d.update_traces(contours_z=dict(show=True, usecolormap=True,
                                     highlightcolor="limegreen", project_z=True))
fig_3d.update_layout(title='Tip shape 3D', autosize=False,
                     width=600, height=500,
                     margin=dict(l=65, r=50, b=65, t=50))
fig_3d.show()

  0%|          | 0/200 [00:00<?, ?it/s]

CPU times: user 18.5 s, sys: 408 ms, total: 18.9 s
Wall time: 23.7 s


In [17]:
#@title Run erosion (deconvolution) with the reconstructed tip
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display
import ipywidgets as widgets
from colabbtr.morphology import ierosion
import numpy as np
import os

custom_path = os.path.join('/content', jobname, 'erosion')
os.makedirs(custom_path, exist_ok=True)

# Erode every frame of the current stack with the reconstructed tip. `nframe`
# is recomputed here rather than trusting a value left behind by another cell.
nframe = images.shape[0]
surface = torch.zeros_like(images, device=device)
for iframe in range(nframe):
    surface[iframe, :, :] = ierosion(images[iframe, :, :], tip)

np.save(f'{custom_path}/surface', surface.to('cpu').numpy())

#@markdown - **Frame range**: for visualization (erosion is applied to all frames)
first_frame = 0#@param {type:"raw"}
last_frame = 30#@param {type:"raw"}
#@markdown - **Colormap range**
vmin = surface.min().item()#@param {type:"raw"}
vmax = surface.max().item()#@param {type:"raw"}
#@markdown - **Output image size**
figsize = 600#@param {type:"raw"}

# Keep the requested range inside the data so the slider cannot index past the last frame.
last_frame = min(last_frame, nframe)
if not 0 <= first_frame < last_frame:
    raise ValueError(
        f"Invalid frame range: first_frame={first_frame}, last_frame={last_frame} "
        f"(data has {nframe} frames)")

# Save parameter file
with open(f'{custom_path}/erosion_params.txt', 'w') as f:
    f.write(f'first_frame: {first_frame}\n')
    f.write(f'last_frame: {last_frame}\n')
    f.write(f'vmin: {vmin}\n')
    f.write(f'vmax: {vmax}\n')
    f.write(f'figsize: {figsize}\n')

# Create the initial plot
fig = go.FigureWidget()

heatmap = go.Heatmap(
    z=surface[first_frame].cpu().numpy(),
    zmin=vmin,
    zmax=vmax,
    colorscale='Hot',
    colorbar=dict(title='Height')
)

fig.add_trace(heatmap)

fig.update_layout(
    title=f'Eroded surface - Frame {first_frame}',
    width=figsize,
    height=figsize,
    xaxis_title='X',
    yaxis_title='Y',
)

# Create a slider widget
slider = widgets.IntSlider(
    value=first_frame,
    min=first_frame,
    max=last_frame-1,
    step=1,
    description='Frame:',
    continuous_update=False,
    layout=widgets.Layout(width='80%')
)

# Bind the figure and surface as default arguments so rebinding the globals
# `fig` / `surface` elsewhere cannot redirect this slider.
def update_plot(change, fig=fig, surface=surface):
    """Redraw the eroded-surface heatmap for the frame selected on the slider."""
    frame = change['new']
    with fig.batch_update():
        fig.data[0].z = surface[frame].cpu().numpy()
        fig.layout.title.text = f'Eroded surface - Frame {frame}'

# Connect the slider to the update function
slider.observe(update_plot, names='value')

# Display the slider and the plot
display(widgets.VBox([slider, fig]))

VBox(children=(IntSlider(value=0, continuous_update=False, description='Frame:', layout=Layout(width='80%'), m…

In [None]:
#@title Work in progress... (Optional) Tip shape reconstruction across frames (used for anomaly detection of tip shapes)

from colabbtr.morphology import differentiable_btr
import torch
from tqdm.notebook import tqdm
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
import numpy as np

custom_path = os.path.join('/content', jobname, 'SSIM_score')
os.makedirs(custom_path, exist_ok=True)

# `images` is indexed as [frame, height, width] by the earlier cells.
#@markdown - **Frame range**
window_frames = 10#@param {type:"raw"}
first_frame = 0#@param {type:"raw"}
nframe = images.shape[0]
last_frame = nframe#@param {type:"raw"}
#@markdown - **Tip size**: used for tip height and width
tip_size = 10#@param {type:"raw"}
tip_height = tip_size
tip_width = tip_size
#@markdown - **Settings for learning**
epoch = 200#@param {type:"raw"}
learning_rate = 0.1#@param {type:"raw"}
weight_decay = 0.0#@param {type:"raw"}

# Windows must lie fully inside the data.
last_frame = min(last_frame, nframe)

# Save parameter file
with open(f'{custom_path}/SSIM_params.txt', 'w') as f:
    f.write(f'window_frames: {window_frames}\n')
    f.write(f'first_frame: {first_frame}\n')
    f.write(f'last_frame: {last_frame}\n')
    f.write(f'tip_size: {tip_size}\n')
    f.write(f'epoch: {epoch}\n')
    f.write(f'learning_rate: {learning_rate}\n')
    f.write(f'weight_decay: {weight_decay}\n')

# Window k covers frames [first_frame + k, first_frame + k + window_frames) and
# never extends past last_frame, so every window has exactly window_frames frames.
num_windows = (last_frame - first_frame) - window_frames + 1
if window_frames < 1 or first_frame < 0 or num_windows < 1:
    raise ValueError(
        f"Need 1 <= window_frames <= last_frame - first_frame; got window_frames={window_frames}, "
        f"first_frame={first_frame}, last_frame={last_frame}")

# One estimated tip per window, indexed by window number k (not by absolute frame).
tips = torch.zeros((num_windows, tip_height, tip_width), device=images.device, dtype=images.dtype)

for k in tqdm(range(num_windows)):
    start_frame = first_frame + k
    end_frame = start_frame + window_frames

    # Extract the frames for the current window
    window_images = images[start_frame:end_frame, :, :]

    # Perform differentiable BTR for this window. Its loss goes to a private name so
    # the `loss` of the main BTR cell is not overwritten.
    tip_shape, _window_loss = differentiable_btr(window_images,
                                                 (tip_height, tip_width),
                                                 nepoch=epoch, lr=learning_rate, weight_decay=weight_decay, is_tqdm=False)

    # Store the estimated tip shape (detached so `tips` never carries autograd history)
    tips[k, :, :] = tip_shape.detach()

# Function to calculate SSIM for each estimated tip shape against the reference tip shape
def calculate_ssim(tips):
    """Compute the SSIM of every estimated tip shape against the first one.

    Parameters
    ----------
    tips : torch.Tensor
        Stack of 2-D tip shapes indexed as ``tips[window, y, x]``.

    Returns
    -------
    list of float
        ``scores[k]`` is the SSIM between ``tips[0]`` and ``tips[k]``
        (``scores[0]`` compares the reference with itself). Empty if ``tips``
        is empty.
    """
    tips_np = tips.detach().cpu().numpy()
    if len(tips_np) == 0:
        return []
    reference_tip = tips_np[0]

    # skimage cannot infer a data range for floating-point images; use one
    # shared range so scores for different windows are on the same scale.
    data_range = float(tips_np.max() - tips_np.min()) or 1.0

    # The SSIM window must be odd and no larger than the tip itself.
    win_size = min(7, *reference_tip.shape)
    if win_size % 2 == 0:
        win_size -= 1

    # channel_axis=None: each tip is a single-channel 2-D image. (False would
    # be interpreted as axis 0 and treat the rows as channels.)
    return [
        float(ssim(reference_tip, tip_np, data_range=data_range,
                   win_size=win_size, channel_axis=None))
        for tip_np in tips_np
    ]

# Calculate SSIM scores
ssim_scores = calculate_ssim(tips)
np.save(f'{custom_path}/ssim_scores', ssim_scores)

# Plot SSIM scores
plt.figure(figsize=(10, 5))
plt.plot(ssim_scores, marker='o')
plt.xlabel('Slidig window frame')
plt.ylabel('SSIM with first window tip')
plt.title('SSIM scores of estimated tip shapes across frames')
plt.grid(True)
plt.savefig(f'{custom_path}/SSIM_scores.png')
plt.show()

# Identify frames with potential anomalies based on SSIM threshold
threshold = 0.4  # Example threshold, adjust based on your dataset
anomaly_frames = [i for i, score in enumerate(ssim_scores) if score < threshold]
print("Potential anomaly frames:", anomaly_frames)
with open(f'{custom_path}/SSIM_anomaly_frames.txt', 'w') as f:
  f.write(f'Potential anomaly frames: {anomaly_frames}')

  0%|          | 0/576 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
#@title Download results
from google.colab import files
import datetime
import os
import shutil
import urllib.request

job_dir = os.path.join('/content', jobname)

with open(os.path.join(job_dir, 'info.txt'), 'w') as f:
    f.write(f'Creation date and time: {datetime.datetime.now()}\n')

# Ship the ColabBTR license with the results. This is best effort: a network
# failure should not block downloading the results.
license_url = 'https://raw.githubusercontent.com/matsunagalab/ColabBTR/main/LICENSE'
try:
    urllib.request.urlretrieve(license_url, os.path.join(job_dir, 'LICENSE'))
except OSError as e:
    print(f'Warning: could not download LICENSE: {e}')

# Zip the job directory (archive entries start with "<jobname>/") and download it.
# shutil replaces `os.system("cd /content")`, which only changed the directory of
# a throwaway subshell, and avoids interpolating `jobname` into a shell command.
archive_path = shutil.make_archive(os.path.join('/content', f'{jobname}_result'), 'zip',
                                   root_dir='/content', base_dir=jobname)
files.download(archive_path)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>