In [None]:
pip install einops

In [None]:
pip install timm

In [None]:
import numpy as np
from preclassify import dicomp, hcluster
import cv2
import torch
import torch.nn as nn
import time
import numpy as np
from sklearn.manifold import TSNE
from scipy import io as sio
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.optim as optim
from einops.layers.torch import Rearrange
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from timm.models.layers import trunc_normal_
from torch.autograd import Function
from torch_wavelets import DWT_2D, IDWT_2D
import pywt
from skimage import io, measure
import random
import math
from tqdm.notebook import tqdm
import os
from PIL import Image

ModuleNotFoundError: No module named 'sklearn'

In [None]:
def pair(t):
    """Return ``t`` unchanged when it is a tuple, otherwise duplicate it into ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


seed = 3407


def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic behaviour.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point; hash randomisation
    # of the running kernel was fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select different convolution algorithms from run to run,
    # which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False


seed_everything(seed)


In [None]:
def image_normalize(data):
    """Standardise ``data`` to zero mean and (at most) unit variance.

    The divisor is ``max(std, 1 / sqrt(N))`` where ``N`` is the number of
    elements, so a constant input is mapped to zeros instead of dividing by
    zero.
    """
    import math

    n_elements = float(np.size(data))
    std_floor = 1.0 / math.sqrt(n_elements)
    centered = data - np.mean(data)
    return centered / max(np.std(data), std_floor)


def image_padding(data, r):
    """Zero-pad the first two axes of ``data`` by ``r`` elements on every side.

    Any trailing axes (e.g. the channel axis of an ``(H, W, C)`` image) are
    left unpadded.

    Args:
        data: array with at least two dimensions.
        r: number of zeros added before and after each of the first two axes.

    Returns:
        A new array whose first two axes are each ``2 * r`` longer.

    Raises:
        ValueError: if ``data`` has fewer than two dimensions.
    """
    ndim = len(data.shape)
    if ndim < 2:
        raise ValueError('image_padding expects an array with at least 2 dimensions, got %d' % ndim)
    # Pad only the two leading axes; every remaining axis gets (0, 0).
    pad_width = ((r, r), (r, r)) + ((0, 0),) * (ndim - 2)
    return np.pad(data, pad_width, mode='constant', constant_values=0)


def arr(length):
    """Return the integers ``0 .. length - 2`` in random order.

    The permutation is drawn with Python's ``random`` module, so it is
    reproducible after ``random.seed``.

    NOTE(review): the range stops at ``length - 2``, so index ``length - 1``
    is never produced -- confirm whether excluding the last element is
    intended.
    """
    shuffled = np.arange(length - 1)
    random.shuffle(shuffled)
    return shuffled


def createTrainingCubes(X, y, patch_size, train_len=10000, num_class1=7000):
    """Build a fixed-size training set of image patches from labelled pixels.

    For every sampled pixel whose label in ``y`` is 1 or 2, a
    ``patch_size x patch_size`` window is cut out of the zero-padded image.
    Pixels with any other label are ignored. The sampled pixel sits at offset
    ``patch_size // 2`` inside its window.

    Args:
        X: image array of shape ``(H, W, C)``.
        y: label map indexed as ``y[row, col]`` over the same ``H x W`` grid.
        patch_size: side length of the square patches (even or odd).
        train_len: total number of samples returned.
        num_class1: how many of the samples are drawn from label 1; the
            remaining ``train_len - num_class1`` are drawn from label 2.

    Returns:
        ``(pdata, plabels)`` where ``pdata`` has shape
        ``(train_len, patch_size, patch_size, C)`` and ``plabels`` has shape
        ``(train_len,)`` holding 0.0 for label 1 and 1.0 for label 2. The
        label-1 samples come first.

    Raises:
        ValueError: if ``num_class1`` is outside ``[0, train_len]`` or a class
            does not provide enough candidate pixels.
    """
    if not 0 <= num_class1 <= train_len:
        raise ValueError('num_class1 must lie in [0, train_len]')
    num_class2 = train_len - num_class1

    # Equals int((patch_size - 1) / 2) + 1 for even sizes; for odd sizes it
    # centres the window on the pixel.
    margin = patch_size // 2
    zeroPaddedX = image_padding(X, margin)

    # argwhere enumerates matches in row-major order, i.e. the order a nested
    # row/column scan would visit them.
    coords_1 = np.argwhere(y == 1)
    coords_2 = np.argwhere(y == 2)

    # Shuffled candidate indices per class (class 1 is drawn first so the
    # random stream is consumed in a fixed order).
    order_1 = arr(len(coords_1))
    order_2 = arr(len(coords_2))
    if len(order_1) < num_class1 or len(order_2) < num_class2:
        raise ValueError(
            'not enough labelled pixels: need %d/%d candidates for labels 1/2, got %d/%d'
            % (num_class1, num_class2, len(order_1), len(order_2)))

    pdata = np.zeros((train_len, patch_size, patch_size, X.shape[2]))
    plabels = np.zeros(train_len)

    chosen = np.concatenate([coords_1[order_1[:num_class1]],
                             coords_2[order_2[:num_class2]]])
    # Only the sampled pixels are turned into patches. Row r of the original
    # image is row r + margin of the padded one, so the window starts at r.
    for i, (r, c) in enumerate(chosen):
        pdata[i] = zeroPaddedX[r:r + patch_size, c:c + patch_size]
    plabels[num_class1:] = 1

    return pdata, plabels


def createTestingCubes(X, patch_size):
    """Cut one ``patch_size x patch_size`` window per pixel of ``X``.

    The image is zero-padded so that border pixels also get a full window.
    Patches are stored in row-major pixel order, so the patch of pixel
    ``(r, c)`` is at index ``r * W + c``.

    Args:
        X: image array of shape ``(H, W, C)``.
        patch_size: side length of the square patches (even or odd).

    Returns:
        Array of shape ``(H * W, patch_size, patch_size, C)``.
    """
    # Equals int((patch_size - 1) / 2) + 1 for even sizes; for odd sizes it
    # centres the window on the pixel.
    margin = patch_size // 2
    zeroPaddedX = image_padding(X, margin)
    height, width = X.shape[0], X.shape[1]
    patchesData = np.zeros((height * width, patch_size, patch_size, X.shape[2]))
    for r in range(height):
        row_offset = r * width
        for c in range(width):
            # Row r of the original image is row r + margin of the padded one,
            # so the window starts at r.
            patchesData[row_offset + c] = zeroPaddedX[r:r + patch_size, c:c + patch_size]
    return patchesData


def evaluate(gtImg, tstImg):
    """Print and return change-detection metrics for a change map.

    Both images are binarised at 128: values ``>= 128`` count as changed,
    everything else as unchanged. The inputs are not modified.

    Args:
        gtImg: 2-D ground truth image.
        tstImg: 2-D change map with the same shape as ``gtImg``.

    Returns:
        dict with the keys
            ``'FA'``  false alarms (unchanged in ground truth, changed in map),
            ``'MA'``  missed alarms (changed in ground truth, unchanged in map),
            ``'OE'``  overall error ``FA + MA``,
            ``'PCC'`` overall accuracy as a fraction in ``[0, 1]``,
            ``'KC'``  kappa coefficient (``nan`` when the expected agreement
                      is exactly 1).

    Raises:
        ValueError: if the two images do not have the same shape.
    """
    gt_changed = np.asarray(gtImg) >= 128
    tst_changed = np.asarray(tstImg) >= 128
    if gt_changed.shape != tst_changed.shape:
        raise ValueError('gtImg and tstImg must have the same shape')

    ylen, xlen = gt_changed.shape
    total = ylen * xlen

    # Plain Python ints keep the products in PRE free of fixed-width overflow.
    label_1 = int(np.count_nonzero(gt_changed))
    label_0 = total - label_1
    print(label_0)
    print(label_1)

    FA = int(np.count_nonzero(~gt_changed & tst_changed))
    MA = int(np.count_nonzero(gt_changed & ~tst_changed))

    OE = FA + MA
    PCC = 1 - OE / total
    # Expected agreement by chance, from the marginal class counts.
    PRE = ((label_1 + FA - MA) * label_1 + (label_0 + MA - FA) * label_0) / (total * total)
    KC = (PCC - PRE) / (1 - PRE) if PRE != 1 else float('nan')
    print(' Change detection results ==>')
    print(' ... ... FP:  ', FA)
    print(' ... ... FN:  ', MA)
    print(' ... ... OE:  ', OE)
    print(' ... ... PCC: ', format(PCC * 100, '.2f'))
    print(' ... ... KC: ', format(KC * 100, '.2f'))
    return {'FA': FA, 'MA': MA, 'OE': OE, 'PCC': PCC, 'KC': KC}


def _suppress_small_regions(res, max_size, fill_value):
    """Return a copy of ``res`` with every small connected region overwritten.

    Regions are found with ``measure.label(res, connectivity=2)``; label 0 is
    the background and is never touched. Every region with at most
    ``max_size`` pixels is set to ``fill_value``.
    """
    labels = measure.label(res, connectivity=2)
    # One pass over the label image gives the pixel count of every region.
    region_sizes = np.bincount(labels.ravel(), minlength=1)
    is_small = region_sizes <= max_size
    is_small[0] = False
    # Work on a copy so the caller's array is left untouched.
    res_new = res.copy()
    res_new[is_small[labels]] = fill_value
    return res_new


def postprocess1(res):
    """Set connected regions of at most 15 pixels to 0.5.

    Args:
        res: image array to clean.

    Returns:
        A new array; ``res`` itself is not modified.
    """
    return _suppress_small_regions(res, max_size=15, fill_value=0.5)


def postprocess(res):
    """Remove (set to 0) connected regions of at most 20 pixels.

    Args:
        res: image array to clean.

    Returns:
        A new array; ``res`` itself is not modified.
    """
    return _suppress_small_regions(res, max_size=20, fill_value=0)

In [None]:

im1_path = '/content/Sulzberger2_1.bmp'
im2_path = '/content/Sulzberger2_2.bmp'
imgt_path = '/content/Sulzberger2_gt.bmp'


patch_size = 8


def load_gray_image(path):
    """Read an image file as a 2-D float32 array.

    Files that load with a channel axis are reduced to their first channel;
    files that already load as 2-D arrays are used as they are.
    """
    image = io.imread(path)
    if image.ndim == 3:
        image = image[:, :, 0]
    return image.astype(np.float32)


im1 = load_gray_image(im1_path)
im2 = load_gray_image(im2_path)
im_gt = load_gray_image(imgt_path)


# dicomp / hcluster come from the local `preclassify` module.
# NOTE(review): presumably a difference image and a per-pixel pre-classification;
# later cells treat label 1.5 as "uncertain" and labels 1 / 2 as confident classes.
im_di = dicomp(im1, im2)
ylen, xlen = im_di.shape
pix_vec = im_di.reshape([ylen * xlen, 1])

preclassify_lab = hcluster(pix_vec, im_di)
print('... ... hiearchical clustering finished !!!')

# Stack both acquisitions and the difference image as a 3-channel input.
mdata = np.zeros([im1.shape[0], im1.shape[1], 3], dtype=np.float32)
mdata[:, :, 0] = im1
mdata[:, :, 1] = im2
mdata[:, :, 2] = im_di
mlabel = preclassify_lab

# Patches are built channel-last and moved to channel-first for PyTorch.
x_train, y_train = createTrainingCubes(mdata, mlabel, patch_size)
x_train = x_train.transpose(0, 3, 1, 2)
# print(y_train.shape) #[10000]
x_test = createTestingCubes(mdata, patch_size)
x_test = x_test.transpose(0, 3, 1, 2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TrainDS(torch.utils.data.Dataset):
    """Map-style dataset over the module-level ``x_train`` / ``y_train`` arrays."""

    def __init__(self):
        # Samples are stored as float tensors, labels as long (int64) tensors.
        self.x_data = torch.FloatTensor(x_train)
        self.y_data = torch.LongTensor(y_train)
        self.len = x_train.shape[0]

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        sample = self.x_data[index]
        target = self.y_data[index]
        return sample, target


# Shuffled mini-batches of 128 training patches, fetched by two worker processes.
trainset = TrainDS()
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)




In [None]:
class WaveAttention(nn.Module):
    """Multi-head attention whose keys/values come from a wavelet-transformed feature map.

    Queries are computed from the input tokens. The tokens are also reshaped
    to a feature map, reduced to ``dim // 4`` channels, passed through a Haar
    DWT and a 3x3 conv filter, and then used twice: as the source of keys and
    values, and (after the inverse DWT) as an extra feature that is
    concatenated to the attention output before the final projection.

    Args:
        sr_ratio: kernel size and stride of the conv applied to the wavelet
            features before keys/values are computed (identity when <= 1).
        dim: token embedding size; ``forward`` requires it to be divisible by
            ``heads``, and the wavelet branch uses ``dim // 4`` channels.
        heads: number of attention heads.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, sr_ratio, dim, heads, dropout):
        super(WaveAttention, self).__init__()
        self.heads = heads
        self.dim = dim
        self.head_dim = dim // heads
        self.sr_ratio = sr_ratio
        self.scale = self.head_dim ** -0.5
        self.dwt = DWT_2D(wave="haar")
        self.idwt = IDWT_2D(wave="haar")

        # Channel reduction ahead of the DWT.
        # NOTE(review): the DWT output is fed to a conv expecting `dim` channels,
        # so it presumably stacks 4 sub-bands along the channel axis -- confirm
        # in torch_wavelets.
        self.reduce = nn.Sequential(
            nn.Conv2d(dim, dim // 4, kernel_size=1, padding=0, stride=1),
            nn.BatchNorm2d(dim // 4),
            nn.ReLU(inplace=True)
        )
        self.filter = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, groups=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True)
        )
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 2)
        )

        self.kv_embed = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) if sr_ratio > 1 else nn.Identity()

        # Input is the attention output (dim) concatenated with the inverse-DWT
        # features (dim // 4).
        self.proj = nn.Sequential(
            nn.Linear(dim + dim // 4, dim),
            nn.Dropout(dropout)
        )

        self.apply(self._init_weights)


    def forward(self, x, H, W):
        """Apply wavelet attention to a token sequence.

        Args:
            x: tensor of shape ``(B, N, C)`` with ``N == H * W`` and ``C == dim``.
            H: height of the token grid.
            W: width of the token grid.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.heads, self.head_dim).permute(0, 2, 1, 3)
        # checkShape("q", q)
        x = x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = self.reduce(x)
        # The wavelet ops are run in half precision. Cast with `.to()` so the
        # tensor stays in the autograd graph: building a new tensor with
        # torch.tensor(x) would copy *and detach*, leaving `self.reduce`
        # without gradients.
        x = x.to(torch.float16)
        x_dwt = self.dwt(x)
        x_dwt = x_dwt.float()
        x_dwt = self.filter(x_dwt)

        x_dwt = x_dwt.half()
        x_idwt = self.idwt(x_dwt)
        x_idwt = x_idwt.view(B, -1, x_idwt.size(-2) * x_idwt.size(-1)).transpose(1, 2)

        x_dwt = x_dwt.float()
        kv = self.kv_embed(x_dwt).reshape(B, C, -1).permute(0, 2, 1)
        kv = self.kv(kv).reshape(B, -1, 2, self.heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k = kv[0]
        v = kv[1]
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)  # N ->H*W
        x = self.proj(torch.cat([x, x_idwt], dim=-1))
        return x

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm and Conv2d sub-modules; other module types are left as created."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init scaled by the per-group fan-out of the convolution.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

class WSM(nn.Module):
    """Patch embedding followed by a single ``WaveAttention`` block.

    The image is split into non-overlapping patches, each patch is linearly
    embedded, the tokens go through ``WaveAttention`` and are finally folded
    back into a channel-last image.

    Args:
        image_size: int or ``(height, width)`` of the input images.
        patch_size: int or ``(height, width)`` of one patch; must divide
            ``image_size``.
        dim: token embedding size. The final rearrange splits it into
            ``patch_height * patch_width * c``, so it must be a multiple of
            the patch area.
        heads: number of attention heads.
        pool: must be ``'cls'`` or ``'mean'`` (only validated).
        channels: number of input image channels.
        dropout: dropout probability inside ``WaveAttention``.
        num_classes, depth, mlp_dim, dim_head, emb_dropout: accepted but not
            used by this module.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be ' \
                                                                                    'divisible by the patch size. '
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # Size of the token grid; WaveAttention needs it to rebuild a 2-D map.
        self.grid_height = image_height // patch_height
        self.grid_width = image_width // patch_width
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),
        )
        self.transformer = WaveAttention(sr_ratio=2, dim=dim, heads=heads, dropout=dropout)
        self.reshape = Rearrange('b (h w) (p1 p2  c) -> b (h p1) (w p2) c', p1=patch_height, p2=patch_width,
                                 h=self.grid_height)

    def forward(self, img):
        """Map a channel-first image batch ``(B, C, H, W)`` to a channel-last feature map."""
        x = self.to_patch_embedding(img)
        # Grid size is derived from the constructor arguments rather than fixed.
        x = self.transformer(x, H=self.grid_height, W=self.grid_width)
        x = self.reshape(x)
        return x


# The network input is one `patch_size x patch_size` window (global from the data
# cell), so the WSM image size follows that setting instead of a literal 8.
wsm = WSM(image_size=patch_size, patch_size=2, num_classes=2, dim=12, depth=6, heads=4, mlp_dim=32).to(device)

In [None]:
class BiAttn(nn.Module):
    """Gate a channel-last feature map with a channel gate and a spatial gate.

    A globally pooled descriptor produces one gate per channel; the local
    features concatenated with that descriptor produce one gate per position.
    The input is multiplied by the product of both gates.

    Args:
        in_channels: size of the last (channel) axis of the input.
        act_ratio: ratio of the hidden width to ``in_channels``.
        act_fn: activation class used after the two reduce layers.
        gate_fn: activation class that turns logits into gates.
    """

    def __init__(self, in_channels, act_ratio=0.5, act_fn=nn.GELU, gate_fn=nn.Sigmoid):
        super().__init__()
        reduce_channels = int(in_channels * act_ratio)
        self.norm = nn.LayerNorm(in_channels)
        self.global_reduce = nn.Linear(in_channels, reduce_channels)
        self.local_reduce = nn.Linear(in_channels, reduce_channels)
        self.act_fn = act_fn()
        self.channel_select = nn.Linear(reduce_channels, in_channels)
        self.spatial_select = nn.Linear(reduce_channels * 2, 1)
        self.gate_fn = gate_fn()


    def forward(self, x):
        """Apply the attention gates.

        Args:
            x: tensor of shape ``(B, H, W, C)`` with ``C == in_channels``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        ori_x = x

        x = self.norm(x)
        height, width = x.shape[1], x.shape[2]
        # Global average pooling: the kernel must span the full H x W extent of
        # the channel-first view so that non-square inputs are pooled correctly.
        x_global = F.avg_pool2d(x.permute(0, 3, 1, 2), kernel_size=(height, width))
        x_global = x_global.permute(0, 2, 3, 1)
        x_global = self.act_fn(self.global_reduce(x_global))
        x_local = self.act_fn(self.local_reduce(x))

        c_attn = self.channel_select(x_global)
        c_attn = self.gate_fn(c_attn)
        # Broadcast the global descriptor to every position before the spatial gate.
        s_attn = self.spatial_select(torch.cat([x_local, x_global.repeat(1, height, width, 1)], dim=-1))
        s_attn = self.gate_fn(s_attn)
        attn = c_attn * s_attn
        return ori_x * attn


In [None]:
class ConcactFeature(nn.Module):
    """1x1 conv -> LayerNorm -> 1x1 conv over a channel-first feature map.

    Args:
        dim: number of channels of the input and output feature map.

    The LayerNorm normalises over ``(dim, patch_size, patch_size)``, so the
    spatial size of the input must equal the module-level ``patch_size``.
    """

    def __init__(self, dim=3):
        super(ConcactFeature, self).__init__()
        # Channel count comes from `dim` instead of a literal 3.
        self.catConv = nn.Conv2d(dim, dim, kernel_size=1)
        self.norm1 = nn.LayerNorm([dim, patch_size, patch_size])
        self.conv = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        x = self.catConv(x)
        x = self.norm1(x)
        x = self.conv(x)
        return x


class WBANet(nn.Module):
    """Two-branch patch classifier producing 2 logits per input patch.

    The input goes through the module-level ``wsm`` instance and through a
    ``BiAttn`` block; both channel-last outputs are concatenated along the
    channel axis, fused by a 1x1 conv, refined by ``ConcactFeature`` and
    classified by two linear layers.

    NOTE(review): ``self.drop`` is created but never applied in ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Shared, module-level WSM instance (not created per WBANet).
        self.wsm = wsm
        self.bam = BiAttn(in_channels=3)
        self.cf = ConcactFeature()
        self.linear1 = nn.Linear(patch_size * patch_size * 3, 20)
        self.linear2 = nn.Linear(20, 2)
        self.drop = nn.Dropout(0.2)
        self.catConv = nn.Conv2d(6, 3, kernel_size=1)

    def forward(self, img):
        """Return logits of shape ``(B, 2)`` for a channel-first batch ``img``."""
        wavelet_feats = self.wsm(img)
        attn_feats = self.bam(img.permute(0, 2, 3, 1))

        # Both branches are channel-last: join on the channel axis, then move
        # channels first for the convolutions.
        fused = torch.cat((wavelet_feats, attn_feats), 3)
        fused = self.catConv(fused.permute(0, 3, 1, 2))

        refined = self.cf(fused)

        flat = refined.view(refined.size(0), -1)  # 128 192
        hidden = self.linear1(flat)
        return self.linear2(hidden)




In [None]:
gamma = 0.75
criterion = nn.CrossEntropyLoss()
model = WBANet().to(device)
lr = 5e-3
epoch = 10
optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE(review): this scheduler is created but `scheduler.step()` is never called,
# so training runs with a constant learning rate. Left unchanged because stepping
# it would alter the training behaviour -- decide explicitly.
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
start_time = time.time()

# The model reuses the module-level `wsm`, which a previous `model.eval()` may
# have left in eval mode; make sure everything trains.
model.train()
num_batches = len(train_loader)
for epoch_idx in range(epoch):
    epoch_loss = 0.0
    epoch_accuracy = 0.0
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)
        output = model(data)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        # .item() stores plain floats, so the running means do not keep
        # tensors (and their autograd history) alive across iterations.
        epoch_accuracy += acc.item() / num_batches
        epoch_loss += loss.item() / num_batches

    print(f"Epoch : {epoch_idx + 1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f}\n")



end_time = time.time()


execution_time = end_time - start_time


print(f"train run time: {execution_time:.4f} seconds")

In [None]:

model.eval()
start_time = time.time()

# Confident pre-classification labels are kept; only pixels labelled 1.5
# (uncertain) are classified by the network.
outputs = np.array(preclassify_lab[:ylen, :xlen], dtype=np.float64)
flat_outputs = outputs.reshape(-1)  # view: writes below land in `outputs`
# Flat index i * xlen + j is also the row of pixel (i, j) in x_test.
uncertain_idx = np.flatnonzero(flat_outputs == 1.5)

inference_batch_size = 1024
# No gradients are needed at test time; batching avoids one forward pass per pixel.
with torch.no_grad():
    for start in tqdm(range(0, len(uncertain_idx), inference_batch_size)):
        batch_idx = uncertain_idx[start:start + inference_batch_size]
        batch = torch.as_tensor(x_test[batch_idx], dtype=torch.float32).to(device)
        prediction = model(batch).argmax(dim=1).cpu().numpy()
        # Network classes 0 / 1 map back to labels 1 / 2.
        flat_outputs[batch_idx] = prediction + 1


outputs = outputs - 1
res = outputs * 255
res = postprocess(res)
evaluate(im_gt, res)
plt.imshow(res, 'gray')
plt.axis('off')  # remove coordinate axis
plt.xticks([])  # remove x axis
plt.yticks([])  # remove y axis

plt.show()

end_time = time.time()

execution_time = end_time - start_time

print(f"test run time: {execution_time:.4f} seconds")



In [None]:
# Mount Google Drive under /content/drive; `google.colab` is only importable
# inside Google Colab.
# NOTE(review): this runs after the cells that read the images from /content --
# confirm whether the mount is still needed or belongs at the top.
from google.colab import drive
drive.mount('/content/drive')