In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from pathlib import Path
import platform
from smb.SMBConnection import SMBConnection
from tqdm import tqdm
import os
import io
import random
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt

class Smb:
    """Context-managed wrapper around a pysmb ``SMBConnection``.

    The connection is opened and authenticated in ``__enter__`` and closed in
    ``__exit__``, so use it as::

        with Smb(username, password, remote_name, ip) as smb:
            smb.copy_dir(share, remote_dir, local_dir)
    """

    def __init__(self, username, password, remote_name, ip, port=139):
        """Create (but do not open) the SMB connection.

        Args:
            username: Account name used to authenticate.
            password: Password for ``username``.
            remote_name: Remote machine name passed to ``SMBConnection``.
            ip: Address the socket connects to.
            port: TCP port. 139 (the default) uses NetBIOS over TCP/IP;
                445 uses direct TCP, which pysmb only speaks when
                ``is_direct_tcp=True``.
        """
        self.conn = SMBConnection(
            username, password, platform.node(), remote_name,
            is_direct_tcp=(port == 445))
        self.ip = ip
        self.port = port

    def __enter__(self):
        """Connect and authenticate; return ``self``.

        Raises:
            ConnectionError: if ``SMBConnection.connect`` reports failure.
        """
        try:
            connected = self.conn.connect(self.ip, self.port)
        except Exception:
            # __exit__ is not called when __enter__ raises, so release the
            # socket here before propagating.
            self.conn.close()
            raise
        if not connected:
            # connect() signals failed authentication by returning False
            # rather than raising; without this check later calls fail
            # with confusing errors.
            self.conn.close()
            raise ConnectionError(
                'SMB authentication to {}:{} failed'.format(self.ip, self.port))
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection; exceptions from the ``with`` body propagate."""
        self.conn.close()

    def echo(self, data):
        """Send ``data`` as an SMB echo request and return the reply."""
        return self.conn.echo(data)

    def get_item_list(self, share_dir: str, target_dir: str):
        """Return the names of every entry ``listPath`` reports for ``target_dir``.

        Entries are not filtered, so sub-directory names (and typically
        ``'.'`` / ``'..'``) are included.
        """
        items = self.conn.listPath(share_dir, target_dir)
        return [item.filename for item in items]

    def copy_file(self, share_dir: str, file_path: str, dst_dir: str, exists_ok=True):
        """Download one remote file into the local directory ``dst_dir``.

        Args:
            share_dir: Name of the SMB share.
            file_path: Path of the file inside the share.
            dst_dir: Local directory, created if missing. The file keeps its
                remote base name.
            exists_ok: Forwarded to ``os.makedirs``. When False, an existing
                ``dst_dir`` raises ``FileExistsError``.
        """
        os.makedirs(dst_dir, exist_ok=exists_ok)
        dst_file_path = Path(dst_dir) / Path(file_path).name
        self._download(share_dir, file_path, dst_file_path)
        return None

    def copy_dir(self, share_dir: str, dir_path: str, dst_dir: str):
        """Recursively download remote directory ``dir_path`` into ``dst_dir``.

        Sub-directories are mirrored locally. The ``'.'`` and ``'..'``
        entries returned by ``listPath`` are skipped.
        """
        os.makedirs(dst_dir, exist_ok=True)

        entries = self.conn.listPath(share_dir, dir_path)
        for entry in tqdm(entries):
            if entry.filename in ('.', '..'):
                continue
            remote_filepath = str(Path(dir_path) / entry.filename)
            save_filepath = str(Path(dst_dir) / entry.filename)
            if entry.isDirectory:
                self.copy_dir(share_dir, remote_filepath, save_filepath)
            else:
                self._download(share_dir, remote_filepath, save_filepath)

        return None

    def _download(self, share_dir, remote_path, local_path):
        """Stream one remote file to ``local_path``, deleting it if the transfer fails."""
        # open() stays outside the try: if it fails, nothing was created
        # and any pre-existing file must not be deleted.
        fp = open(local_path, 'wb')
        try:
            with fp:
                self.conn.retrieveFile(share_dir, remote_path, fp)
        except BaseException:
            # Don't leave a truncated file that looks like a finished copy.
            os.remove(local_path)
            raise

In [None]:
def show_img(img, title=None):
    fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150)
    ax.set_title( title, fontsize=16, color='black')
    ax.axes.xaxis.set_visible(False) # X軸を非表示に
    ax.axes.yaxis.set_visible(False) # Y軸を非表示に
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    return fig, ax

def show_imgs(imgs_dict:dict, ncol=0, dpi=400):
    if ncol != 0:
        nrow = ((len(imgs_dict)-1)//ncol)+1
    else:
        nrow = 1
        ncol = len(imgs_dict)

    fig = plt.figure(dpi=dpi)
    for i, key in enumerate(imgs_dict, start=1):
        ax = fig.add_subplot(nrow, ncol, i)
        ax.axis('off')
        ax.imshow(cv2.cvtColor(imgs_dict[key], cv2.COLOR_BGR2RGB))
        ax.set_title(key, fontsize=3, color='black')

    return None

# Connection settings: fill these in before running. Prefer loading real
# credentials from outside the notebook rather than committing them.
params = {
    'username': '',
    'password': '',
    'remote_name': 'XXX.XXX.XXX.XXX',
    'ip': 'XXX.XXX.XXX.XXX',
}

base_share_dir = Path("share-dataset")     # SMB share name passed to listPath/retrieveFile
base_share_subDir = Path("AAA/BBB/CCC/")   # folder inside the share to sample from

# Download one randomly chosen file from the folder and display it.
with Smb(**params) as smb:
    items = smb.conn.listPath(str(base_share_dir), str(base_share_subDir))
    # listPath also returns '.', '..' and sub-directories; only regular
    # files can be retrieved.
    files = [item for item in items if not item.isDirectory]
    if not files:
        raise FileNotFoundError(
            'no files found in {}/{}'.format(base_share_dir, base_share_subDir))
    sample_path = base_share_subDir / random.choice(files).filename

    with io.BytesIO() as buf:
        smb.conn.retrieveFile(str(base_share_dir), str(sample_path), buf)
        buf.seek(0)  # after writing, the cursor sits at the end of the data
        # convert('RGB') normalises grayscale/palette/RGBA/CMYK files to
        # three channels so the colour conversion below always applies.
        img = np.asarray(Image.open(buf).convert('RGB'))
    # PIL yields RGB; show_img expects OpenCV's BGR channel order.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    show_img(img)
