In [None]:
import numpy as np

import matplotlib.pyplot as plt
import os

from PIL import Image
from skimage import io, color, img_as_float, img_as_ubyte
from sklearn.model_selection import train_test_split


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import transforms
from torchvision.models import resnet50
from torchvision.datasets import ImageFolder

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

!pip install torchinfo
from torchinfo import summary

!nvidia-smi

PyTorch Version:  1.7.1
Torchvision Version:  0.8.2
Fri Feb  3 13:53:28 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.34       Driver Version: 430.34       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce RTX 208...  Off  | 00000000:19:00.0 Off |                  N/A |
| 24%   29C    P8     3W / 250W |   1574MiB / 11019MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:1A:00.0 Off |                  N/A |
| 24%   33C    P8     6W / 250W |     11MiB / 11019MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  GeForce RTX 208...  Off  | 00

In [None]:
# Reports the `python` found on the shell PATH, which may differ from the kernel's interpreter.
!python --version

# Project root. Later cells use paths relative to it ("./data/...", "figures/...") and
# import the local `utils` package, which is why the working directory is changed below.
root_path = "/media/kondo/Ext4_for_Colab/tomizawa/paper/" #@param {type:"string"}
#@markdown  - Set `root_path` to your root directory. It must contain the image dataset directory (`data/data_original/`) and the `utils/` directory.
#@markdown  - The image dataset is available from a database (see the URL in the README on the GitHub page).
#@markdown  - The `utils` library is available on the same GitHub page as this Jupyter notebook.

%cd $root_path 
!pwd
!ls

Python 3.6.10 :: Anaconda, Inc.
/media/kondo/Ext4_for_Colab/tomizawa/paper
/media/kondo/Ext4_for_Colab/tomizawa/paper
data	   figures  results_old		 results_test
data_old1  results  results_randomLabel  utils


In [None]:
from utils.loaders import Args, ImageTransform, MarchantiaDataset
from utils.misc import makedirs

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
has_gpu = torch.cuda.is_available()
dev = torch.device("cuda") if has_gpu else torch.device("cpu")
print(dev)

cuda


In [None]:
def visualize_batch(images, ncols=4):
    """Show the first images of a batch side by side in a single row.

    Args:
        images: indexable batch of image tensors. Each element is moved to the
            CPU and transposed from (C, H, W) to (H, W, C) for ``imshow``.
        ncols: maximum number of images to show. It is clamped to
            ``len(images)``, so batches smaller than ``ncols`` no longer raise
            ``IndexError``.

    Returns:
        The matplotlib Figure that was created. The original returned None;
        callers that ignore the return value are unaffected.

    Raises:
        ValueError: if there is nothing to show (empty batch or ``ncols < 1``).
    """
    ncols = min(ncols, len(images))
    if ncols < 1:
        raise ValueError("visualize_batch needs at least one image to display")
    # squeeze=False keeps `axes` 2-D even when ncols == 1. With the default
    # squeeze=True a single Axes is returned and indexing it would fail.
    fig, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(3*ncols, 3), squeeze=False)
    for n, axis in enumerate(axes[0]):
        img = images[n].cpu().numpy().transpose((1, 2, 0))
        if img.shape[-1] == 1:
            # imshow rejects (H, W, 1) arrays; show single-channel images as grayscale.
            axis.imshow(img[..., 0], cmap='gray')
        else:
            axis.imshow(img)
        axis.axis('off')
    return fig

In [None]:
# Export one example image per strain x day x sex as a PDF with a 1 mm scale bar.
strains = ['Aus', 'Tak', 'RIL5']
days = ['0d', '1d', '2d', '3d', '4d', '7d']
ablation = 'original'
results_dname = 'figures/Marchantia_imgs/'
n_img = 0  # position, within each strain/day/sex subset, of the image to export

# Side length in px of the transformed images. The scale bar is drawn 10 px above
# the bottom edge. NOTE(review): assumes the "no_Normalize" transform outputs
# 224 px images, as implied by the 224/800 factor below; confirm against ImageTransform.
IMG_SIZE = 224

# Length in px of a 1 mm scale bar for each imaging day. The formula is
# (1572 - reading) * (IMG_SIZE / 800) / divisor, kept verbatim from the original
# per-day expressions. NOTE(review): presumably 1572/reading are ruler pixel
# positions and the divisor is the number of mm between them; confirm.
# An unlisted day now raises KeyError instead of silently reusing the 0d-3d value.
_ruler_calib = {  # day -> (reading, divisor)
    '0d': (936, 2), '1d': (936, 2), '2d': (936, 2), '3d': (936, 2),
    '4d': (1152, 2),
    '7d': (1032, 5),
}
SCALE_BAR_PX = {
    d: int(np.round((1572 - reading) * (IMG_SIZE / 800) / divisor))
    for d, (reading, divisor) in _ruler_calib.items()
}

# ImageFolder class index of male ('M') / female ('F') images. The mapping for
# Tak is the reverse of the other strains.
# NOTE(review): hard-coded; verify against the class_to_idx printed per folder.
SEX_TO_CLASS = {'Tak': {'M': 0, 'F': 1}}
DEFAULT_SEX_TO_CLASS = {'M': 1, 'F': 0}

makedirs(results_dname)
args = Args()
# The transform does not depend on strain/day/sex, so build it once.
Transform = ImageTransform(fill=0)

for strain in strains:
    sex_to_class = SEX_TO_CLASS.get(strain, DEFAULT_SEX_TO_CLASS)
    for day in days:
        args.root = './data/data_' + ablation + '/' + strain + '/' + day + '/'
        print(args.__dict__)

        full_dataset = ImageFolder(root=args.root)
        print(full_dataset)
        print(full_dataset.class_to_idx)
        targets = np.array(full_dataset.targets)
        # One un-normalized view of this folder, shared by the M and F subsets.
        dataset_noNormalize = MarchantiaDataset(full_dataset, Transform.data_transform["no_Normalize"])

        for MF in ['M', 'F']:
            target_class = sex_to_class[MF]
            target_MF = np.where(targets == target_class)[0]
            print(strain, day, MF)
            if n_img >= len(target_MF):
                raise IndexError(
                    f'{strain} {day} {MF}: n_img={n_img} but only {len(target_MF)} image(s) '
                    f'of class {target_class} in {args.root}')

            # Load only the image being exported rather than a whole batch. This also
            # removes the implicit requirement n_img < args.batch_size.
            img_dataset = data.Subset(dataset_noNormalize, indices=target_MF[n_img:n_img + 1])
            img_loader = data.DataLoader(img_dataset, batch_size=1, shuffle=False)
            images, labels = next(iter(img_loader))

            ## plot and save.
            fig, ax = plt.subplots(1, 1, figsize=(3, 3))
            img = images[0].cpu().numpy().transpose((1, 2, 0))
            ax.imshow(img)
            ax.axis('off')
            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

            ## scale bar = 1mm in all conditions, drawn 10 px from the left and bottom edges.
            bar_y = IMG_SIZE - 10
            ax.plot([10, 10 + SCALE_BAR_PX[day]], [bar_y, bar_y], linewidth=1, color='black')

            fname = results_dname + strain + '_' + day + '_' + MF + '_' + ablation + '_' + str(n_img) + '.pdf'
            fig.savefig(fname, dpi=300)
            plt.show()
            plt.close(fig)  # do not accumulate open figures across loop iterations




Output hidden; open in https://colab.research.google.com to view.