# Make SPU Dataset

In previous work, we focused on classifying a single image.

However, in the real problem setting, we need to classify an SPU (Standard Product Unit) rather than a single image.

An SPU usually has about five images, so it is good practice to combine the per-image predictions into one prediction for the whole SPU.

## Steps

### Overall 

+ data set preparation
    - training set
    - validation set
    - check overlap
+ make prediction
    - rewrite function `makePrediction`
    - voting algorithm, a class
+ model
    - use a model from a checkpoint
    - train from scratch with the new data set

### Details

***data set preparation.*** 

1. read `tmall_result.json`, `taobao_result.json` and `garbage_result.json`, and filter out outliers.
2. for each brand, randomly pick 70% of the SPUs for training and the remaining 30% for validation.
3. use 2% of the images in the validation set as the validation data for the learning curve.
4. ~~store all these pictures in the relevant directories.~~
    ```
    \dataset
        |----train
                |---- 4 sub-dir
        |----val
                |---- 4 sub-dir
        |----vallearn
    ```
5. check whether there is any overlap between the training set and the validation set.
6. the images stay in the same directory structure as before.
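
The 2% learning-curve subset from the list above is not implemented in this notebook. A minimal sketch of one way to draw it, assuming the validation set is available as a list of text lines (one line per image) and that a fixed random seed is acceptable; the name `sample_vallearn` and its parameters are hypothetical:

```python
import random


def sample_vallearn(val_lines, fraction=0.02, seed=0):
    """Pick roughly `fraction` of the validation lines (at least one if any exist)."""
    rng = random.Random(seed)                  # fixed seed keeps the subset reproducible
    k = max(1, int(round(len(val_lines) * fraction)))
    k = min(k, len(val_lines))                 # never ask for more lines than exist
    return rng.sample(list(val_lines), k)      # sampling without replacement
```

Sampling images rather than SPUs is fine here because this subset is only used to report per-image accuracy during training.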

***training***

As usual, we still train the classifier on single images.  
The learning curve reports the prediction accuracy on individual images.

***prediction***

voting algorithms:
+ sum of weighted probabilities
+ majority voting

1. make a prediction for each image in the validation set
2. group the images by SPU
3. apply the voting algorithm
4. report the SPU accuracy.
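
The two voting rules and the four steps above can be sketched as follows. This is a minimal sketch, not the notebook's `makePrediction`: the function names (`group_by_spu`, `vote_sum_prob`, `vote_majority`, `spu_accuracy`) are hypothetical, the per-image class probabilities are assumed to be already computed (one vector per image), and uniform weights are assumed when none are given.

```python
from collections import Counter, defaultdict

import numpy as np


def group_by_spu(spu_ids, probs):
    """Group per-image probability vectors by SPU id."""
    groups = defaultdict(list)
    for spu_id, p in zip(spu_ids, probs):
        groups[spu_id].append(np.asarray(p, dtype=float))
    return groups


def vote_sum_prob(prob_list, weights=None):
    """Sum the (optionally weighted) probabilities and take the argmax."""
    stacked = np.vstack(prob_list)             # shape: (n_images, n_classes)
    if weights is None:
        weights = np.ones(len(prob_list))      # assumption: uniform weights
    weights = np.asarray(weights, dtype=float)
    total = (stacked * weights[:, None]).sum(axis=0)
    return int(np.argmax(total))


def vote_majority(prob_list):
    """Each image votes for its argmax class; the most common class wins."""
    votes = [int(np.argmax(p)) for p in prob_list]
    return Counter(votes).most_common(1)[0][0]


def spu_accuracy(spu_ids, probs, labels, vote=vote_sum_prob):
    """Fraction of SPUs whose voted class equals the SPU label."""
    groups = group_by_spu(spu_ids, probs)
    spu_label = dict(zip(spu_ids, labels))     # all images of an SPU share one label
    correct = sum(1 for s, pl in groups.items() if vote(pl) == spu_label[s])
    return correct / float(len(groups))
```

The two rules can disagree: one very confident image can outweigh two weakly confident ones under the probability sum, while majority voting ignores confidence. Ties in majority voting are not handled specially in this sketch.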

### Implementation

***in the first release***

first implement the data set preparation, then the prediction part, using a model loaded from a checkpoint.

***in the second release***

train from scratch.

## Dataset Preparation

In [13]:
#encoding:utf-8

# setting Chinese encoding
# python 2 kernel
import sys
stdi, stdo, stde = sys.stdin, sys.stdout, sys.stderr
reload(sys)
sys.stdin, sys.stdout, sys.stderr = stdi, stdo, stde
sys.setdefaultencoding('utf-8')

In [14]:
import json

metaPath = "json/tmall_result.json"

def readJson(metaPath):
    '''
    read `tmall_result.json` or `taobao_result.json`
    return a list, in which each element is a json record.
    '''
    ret = []
    with open(metaPath, 'r') as metafile:
        for line in metafile:
            # each line holds one record and ends with ",\n";
            # strip whitespace first so a last line without "\n" still parses
            line = line.strip().rstrip(',')
            if line:
                ret.append(json.loads(line))
    return ret
    

def printRawJson(metadata, end, start = 0):
    '''
    metadata is output of `readJson`
    start, end: display the records with start <= index < end
    print metadata to stdout
    '''
    # print(metadata)
    for idx, record in enumerate(metadata):
        if start <= idx < end:
            print("record " + str(idx) + ":")
            print("\tid   : " + str(record["_id"]))
            print("\ttitle: " + str(record["title"]))
            print("\timage: ")
            if "img" in record:
                for i in record["img"]:
                    print("\t\t" + i)
            print("\tattrs:")
            if "attrs" in record:
                for i in record["attrs"]:
                    print("\t\t" + i)
            print("")
        
    

printRawJson(readJson(metaPath), 1000)

record 0:
	id   : 573743779841
	title: 小米 米家智能石英手表新款男女情侣潮流简约时尚防水学生手表
	image: 
		3040e4c040054a6887ed5d8f80ea2191.jpg
		2bd39903148b483e94d7a52d74f300d4.jpg
		91b3b528cad449ccb2d186910e11ed89.jpg
		b1882928aebf4bfcb76067b8502801d3.jpg
		234479a861d947eea564cbf96d670d0d.jpg
	attrs:
		保修: 全国联保
		成色: 全新
		手表镜面材质: 镀膜玻璃
		是否商场同款: 否
		机芯产地: 中国
		品牌: MIJIA/米家
		型号: 米家智能石英手表
		机芯类型: 石英机芯
		手表种类: 中性
		风格: 时尚
		表带材质: 真皮
		形状: 圆形
		显示方式: 指针式
		上市时间: 2018年夏季
		颜色分类: 灰色 白色 黑色
		防水深度: 3ATM
		表扣款式: 针扣
		表底类型: 普通
		表冠类型: 普通
		表盘厚度: 11mm
		表盘直径: 40mm
		品牌产地: 国内
		流行元素: 大表盘
		表壳材质: 精钢

record 1:
	id   : 564453267459
	title: Lenovo/联想 笔记本电源适配器65W 细圆口
	image: 
		6b1113a4f24f4bd7bf8021bdc3fea144.jpg
		86556cb2ca7c47d78aff1a251c8f4f0e.jpg
		50d5fc0f41434b0b92423baf8a3f520e.jpg
		fca1c42e2d314add9c6f86b0b3ded4f0.jpg
		6c624d1a19f74afd8cd3fb73e537a900.jpg
	attrs:
		证书编号：2016010907834796
		证书状态：有效
		产品名称：电源适配器
		3C规格型号：ADLX65CDGC2A;  输入：100-240VAC,50-60Hz,1.5A; 输出：20Vd...
		产品名称：Lenovo/联想 65W小细圆口
		生产企业: 台达电子工业

	title: 新品ZMI双模智能充电器+充电宝(5200)移动电源充电宝快充便携迷你
	image: 
		31dae6e6e76942a081414a8aa43562c8.jpg
		65ea1c8911184c6086059ed92b09ccdd.jpg
		cedd423160d6410eb49653f32661e3bf.jpg
		7e04e7478ee042f9afe195a56700e3c6.jpg
		270425944ac04210b80481dfa89e9a49.jpg
	attrs:
		产品名称：ZMI ZMI双模智能充电器+充...
		品牌: ZMI
		型号: ZMI双模智能充电器+充电宝(5200)
		外壳材质: 塑料
		尺寸: 72x71x32mm
		生产企业: 江苏紫米电子技术有限公司
		电池容量: 5000mAh
		附加功能: 多U口输出
		颜色分类: 黑色
		电池类型: 锂离子电池

record 124:
	id   : 570008424731
	title: 小米 米家投影仪高清智能投影仪家用  1080P 分辨率
	image: 
		1e8c383a4b0f4d76a711810e26d9f6a0.jpg
		188db32399644af896d230dcd70b7ff2.jpg
		ea91db9d2da14c3282b92989877d33e8.jpg
		7e14178262504d8c8b6f204dc3d106a3.jpg
		2f048b6ac1554bed971ae02bb8a288f0.jpg
	attrs:
		证书编号：2018010903048991
		证书状态：有效
		产品名称：米家投影仪
		3C规格型号：TYY01ZM：220V～ 50Hz 1A
		平台类型: 无
		影片效果: 3D
		投放画面大小: 50英寸~250英寸
		操作系统: 安卓
		支持色彩数目: 10.7亿色
		是否可吊装: 是
		显示技术: DLP
		最佳投放距离: 3米以内
		机体尺寸（cm）: 21.5*11.6*21.8
		梯形校正范围: ±40度
		梯形矫正: 垂直 左右
		灯泡功率: 50W
		灯泡寿命: 30000小时
		片源内容: 优酷 其他 爱奇艺
		生产企

		生产企业: 小米
		颜色分类: 13.3寸灰色 13.3寸蓝色 12.5寸蓝色 12.5寸灰色
		尺寸: 其它尺寸
		风格: 简约
		品牌: Xiaomi/小米

record 245:
	id   : 544536060488
	title: 格力 干衣机GSP20干衣机家用衣服烘干机滚筒烘衣机静音省电大功率
	image: 
		5e6a8077743245c3b17edbe58a81a3fa.jpg
		d9dfed5ddd274282b49a7c3549456862.jpg
		dbb5d46d021a4e82a911151c6921fda2.jpg
		ffd552d12d77413db4ec0010854d29e4.jpg
		9bd56775f690437ea42d21c3e5479cf8.jpg
	attrs:
		产品名称：Gree/格力 GSP20
		保修期: 12个月
		品牌: Gree/格力
		型号: GSP20
		生产企业: 珠海格力电器股份有限公司
		货号: GSP20
		颜色分类: 干衣容量二公斤
		功率: 501W(含)-1000W(含)
		形状: 立方形
		承载重量: 5kg以下
		最大定时范围: 61分钟(含)-120分钟(含)
		有无脚轮: 无
		是否有断电保护功能: 有
		加热方式: PTC
		适用场景: 家用
		支架材质: 其他/other
		控制方式: 机械式

record 246:
	id   : 520122751561
	title: 小米 1MORE/万魔 1more活塞耳机入耳式 音乐男女运动耳麦苹果通用
	image: 
	attrs:
		产品名称：1MORE/万魔 1more活塞耳机...
		兼容平台: ANDROID iOS
		套餐类型: 官方标配
		灵敏度: 97dB/mW
		适用音乐类型: 人声女声流行风类型
		阻抗: 32Ω
		频响范围: 20-20000Hz
		颜色分类: 星空钛 黑色
		佩戴方式: 入耳式
		耳机类型: 有线
		有无麦克风: 带麦
		插头直径: 3.5mm
		耳机插头类型: 直插型
		耳机输出音源: 随身视听
		缆线长度: 1.2M
		耳机类别: 普通耳机 手机线控耳机 HIFI耳机
		品牌: 1MOR

		ec3ad4da8c2e49a2881c678e87b66b36.jpg
	attrs:
		证书编号：2018010808053576
		证书状态：有效
		产品名称：小米电视
		3C规格型号：L32M5-AD ：220V～ 50/60Hz　50W
		产品名称：Xiaomi/小米 小米电视4C ...
		分辨率: 1366x768
		3D类型: 无
		能效等级: 三级
		网络连接方式: 全部支持
		操作系统: MIUI TV版
		品牌: Xiaomi/小米
		型号: 小米电视4C 32英寸

record 420:
	id   : 570483891218
	title: 小米净水器1A滤芯3合一复合滤芯家用龙头过滤芯
	image: 
		ca243c5b4481424ba114c2a44a605812.jpg
		4d12bbfd68dc440887c3d7923357afb8.jpg
		e45d90a968a546199dd5a5de376bca3f.jpg
		cccb4d21f13e4327a53e4e3a84db2859.jpg
		2ebec5a1c9a1435e8e08b686eecd2793.jpg
	attrs:
		品牌: Xiaomi/小米
		货号: 小米净水器1A滤芯

record 421:
	id   : 544549389333
	title: 格力除湿机家用抽湿机静音除湿机dh20eh干燥机吸湿器一机多用
	image: 
		e89bbe260e104eadbee321d04d133c7b.jpg
		bfd5e9a53b0e4da98c9e26b5230148c8.jpg
		ca15cd5f174a4564b0501db5e4124f30.jpg
		2e2aef6850bb4f6c82127ba958397e2c.jpg
		286846be2f314d79b23e95f075057cf5.jpg
	attrs:
		产品名称：Gree/格力 DH20EH
		保修期: 6年
		功率: 330W
		品牌: Gree/格力
		格力抽湿机型号: DH20EH
		智能类型: 其他
		最大日除湿量: 0.83L/h
		水箱容量: 6L
		生产企业: 珠海格力电器股份有限公司
		适用面积:

		申请人名称：珠海格力电器股份有限公司
		制造商名称：珠海格力电器股份有限公司
		产品名称：IH智能电饭煲
		3C产品型号：GDCF-5006C, GDCF-5005C, GDCF-5004C,  5.0L;  GDCF-4...
		3C规格型号：GDCF-50X60Ca   GDCF-5006C, GDCF-5005C, GDCF-5004C,...
		产品名称：TOSOT/大松 GDCF-40X60C...
		品牌: TOSOT/大松
		型号: GDCF-40X60CA
		颜色分类: 金色
		容量: 4L
		控制方式: 微电脑式
		电饭煲多功能: 煲仔饭 蛋糕 预约 定时 煮饭
		内胆材质: 合金
		售后服务: 全国联保
		形状: 方形
		加热方式: IH加热
		适用人数: 5人-6人

record 576:
	id   : 564084131245
	title: Gree/格力取暖器NTFD-X6020B电暖器家用定时摇头静音LED触摸屏
	image: 
		9de576f8318d422fbf232cff29e6b066.jpg
		f239834a7d4248c9954eba13353b1145.jpg
		fb21091ab7f94193b6b6694a9964605b.jpg
		aea50eac123042e7ad68a116bc631925.jpg
		8584bb174e514e95a8e85efb3f6fa462.jpg
	attrs:
		证书编号：2015010707828382
		证书状态：有效
		产品名称：电暖器
		3C规格型号：NTFD-X6020B NTFD-20B 220V～ 50Hz 2000W
		产品名称：Gree/格力 NTFD-X6020B
		保修期: 12个月
		取暖器加热方式: 陶瓷加热
		品牌: Gree/格力
		格力取暖器型号: NTFD-X6020B
		智能类型: 不支持智能
		最大采暖面积(平方米): 20m^2及以上
		生产企业: 珠海格力电器股份有限公司
		电暖器最大功率: 1200W(含)-2000W(含)
		采购地: 中国大陆
		颜色分类: 金褐色
		适用面积: 11m^2 (含)-20m^2 (含)
		档位: 3档

record

		适用面积: 21m^2 (含)-30m^2 (含)
		加湿方式: 出雾

record 723:
	id   : 573841866131
	title: 小米小黄鸡米兔 可爱Q萌减压治愈系cosplay米粉毛绒玩具公仔
	image: 
		8b64fd9d97b548108c8e80a5c54dacb5.jpg
		205576efed1440ac87fee8b3a7a04998.jpg
		d8c8f5ddcee04b8585b384e55cf2ece2.jpg
		f148ec428fd0402ca070c42ee4507180.jpg
		267ba6502c4b4a5a92aa4990afcf30e9.jpg
	attrs:
		品牌: Xiaomi/小米
		货号: 小黄鸡米兔
		玩偶种类: 兔兔
		适用年龄: 3岁 4岁 5岁 6岁 7岁 8岁 9岁 10岁 11岁 12岁 13岁 14岁 14岁以上
		玩具类型: 公仔
		颜色分类: 浅黄色
		材质: 毛绒
		填充物: PP棉
		高度: 25厘米

record 724:
	id   : 543730526000
	title: Gree/格力 KFR-50LW/(505511)NhAaD-3 大2匹定频冷暖客厅空调立式
	image: 
		ebd816a10f5046848a6188826cb06af0.jpg
		47402d6e77db48a2aa03945b217187eb.jpg
		40ab40508d8142c4ae6489db46c452a5.jpg
		2f980a00765d419cb2675f1e26beea57.jpg
		fde66a9a8161479e88d41e9114575a33.jpg
	attrs:
		证书编号：2016010703907107
		证书状态：有效
		申请人名称：珠海格力电器股份有限公司
		制造商名称：珠海格力电器股份有限公司
		产品名称：空调器（分体热泵型落地式房间空调器）
		3C产品型号：整机：KFR-50LW/(505511)NhAaD-3（室内机：KFR-50L(505511)NhA...
		3C规格型号：整机：KFR-50LW/(505511)NhAaD-3（室内机：KFR-50L(505511)NhA.

In [15]:
import copy

idBrandCateFile = "json/id2cate_brand_tmall.json" # id, brand name, category id list
cateFile = "json/merge_c_tmall"  # each line: category name, category id

def substituteIdOneRecord(record, idBrandCateList, cateList):
    '''
    record is one record in json file (output of `readJson`)
    output is a record with brand and category name.
    '''
    # get id
    itemid = record["_id"]
    ##  # file rewind # too slow
    ##  idBrandCate.seek(0)
    ##  cate.seek(0) 
    # make category id set
    cate = dict()
    for line in cateList:
        name, cid = line.strip().split(",")
        cate[cid] = name
    # main part
    brand = ""
    categoryList = []
    for line in idBrandCateList:
        tmp = json.loads(line[:-2])
        # find id in id-brand-category file
        if tmp["_id"] == record["_id"]:
            # get brand and category id list
            brand = tmp["brand"]
            categoryList = tmp["category"]
    # find category names by category id
    # construct new record: original record + brand + [ category id, category name ]
    ret = copy.deepcopy(record)
    ret['brand'] = brand
    tmpList = []
    for i in categoryList:
        tmpList.append(i + "|" + cate[i])
    ret['category'] = tmpList
    # for c in categoryList:
    #     record[]
    
    # return
    return ret

def substituteIdFull(metadata, idBrandCateFile, cateFile): 
    '''
    metadata is the output of `readJson`
    '''
    ret = []
    # read info
    # since file operation (seek(0)) is too slow, read file to buffer at first in one shot
    idBrandCate = open(idBrandCateFile, 'r')
    cate = open(cateFile, 'r')
    idBrandCateList = []
    cateList = []
    for line in idBrandCate:
        idBrandCateList.append(line)
    for line in cate:
        cateList.append(line)
    # main part
    for record in metadata:
        ret.append( substituteIdOneRecord(record, idBrandCateList, cateList) )
    # close files
    idBrandCate.close()
    cate.close()
    return ret
    
def printFullJson(metadata, num):
    '''
    metadata is the output of `substituteIdFull`
    num: how many records to display
    print metadata to stdout
    '''
    # print(metadata)
    for idx, record in enumerate(metadata):
        if idx < num:
            print("record " + str(idx) + ":")
            print("\tid   : " + str(record["_id"]))
            print("\tbrand: " + record["brand"])
            print("\ttitle: " + str(record["title"]))
            print("\timage: ")
            if "img" in record:
                for i in record["img"]:
                    print("\t\t" + i)
            print("\tcateg: ")
            if "category" in record:
                for i in record["category"]:
                    print("\t\t" + i)
            print("\tattrs:")
            if "attrs" in record:
                for i in record["attrs"]:
                    print("\t\t" + i)
            print("----------------------------------------------")
        

fullmetadata = substituteIdFull(readJson(metaPath), idBrandCateFile, cateFile)
printFullJson(fullmetadata, 10)

record 0:
	id   : 573743779841
	brand: xiaomi
	title: 小米 米家智能石英手表新款男女情侣潮流简约时尚防水学生手表
	image: 
		3040e4c040054a6887ed5d8f80ea2191.jpg
		2bd39903148b483e94d7a52d74f300d4.jpg
		91b3b528cad449ccb2d186910e11ed89.jpg
		b1882928aebf4bfcb76067b8502801d3.jpg
		234479a861d947eea564cbf96d670d0d.jpg
	categ: 
		807818185|箱包鞋服及生活周边:服饰
	attrs:
		保修: 全国联保
		成色: 全新
		手表镜面材质: 镀膜玻璃
		是否商场同款: 否
		机芯产地: 中国
		品牌: MIJIA/米家
		型号: 米家智能石英手表
		机芯类型: 石英机芯
		手表种类: 中性
		风格: 时尚
		表带材质: 真皮
		形状: 圆形
		显示方式: 指针式
		上市时间: 2018年夏季
		颜色分类: 灰色 白色 黑色
		防水深度: 3ATM
		表扣款式: 针扣
		表底类型: 普通
		表冠类型: 普通
		表盘厚度: 11mm
		表盘直径: 40mm
		品牌产地: 国内
		流行元素: 大表盘
		表壳材质: 精钢
----------------------------------------------
record 1:
	id   : 564453267459
	brand: lenovo
	title: Lenovo/联想 笔记本电源适配器65W 细圆口
	image: 
		6b1113a4f24f4bd7bf8021bdc3fea144.jpg
		86556cb2ca7c47d78aff1a251c8f4f0e.jpg
		50d5fc0f41434b0b92423baf8a3f520e.jpg
		fca1c42e2d314add9c6f86b0b3ded4f0.jpg
		6c624d1a19f74afd8cd3fb73e537a900.jpg
	categ: 
		1371242209|智能选件与服务:电脑配件
	attrs:
		证书

In [23]:
def getBrand(brandList):
    '''
    brandList: is the path of `label.dat`
    return a dictionary: {brand, idx}
    ''' 
    labelDict = dict()
    labelFile = open(brandList, 'r')
    for line in labelFile:
        idx, brand = line.strip().split(" ")
        labelDict[brand] = int(idx)
    labelFile.close() 
    return labelDict

def groupTmallByBrandList(metaPath, brandList, idBrandCateFile, cateFile):
    '''
    get brand list for SPUs in `tmall_result.json`
    return: [[gree], [xiaomi], [lenovo], ...]. list of list of json object
            the order in return is the same as that in `label.dat`
    NOTICE: there is no category `other` in the output of this function
    metaPath is the path of `tmall_result.json`
    brandList: is the path of `label.dat`
    idBrandCateFile: is the path of `id, brand name, category ids`
    cateFile: is the path of `category id, category name`
    all the thing relevant to category is not important at all here.
    '''
    # get full record. fullmeta is list of json object
    fullmeta = substituteIdFull(readJson(metaPath), idBrandCateFile, cateFile)
    # get all labels
    labelDict = getBrand(brandList)
    num = len(labelDict) - 1 # no `other` category in this function
    # return list: one empty list per brand
    ret = [[] for _ in range(num)]
    # loop the fullmeta
    for record in fullmeta:
        # apply filter: remove missing-img, missing-brand SPUs
        if "img" in record and record['brand'] != "":
            brandIdx = labelDict[record['brand']]
            ret[brandIdx].append(record)
    return ret
    
    
groupMeta = groupTmallByBrandList(metaPath="json/tmall_result.json", brandList="json/label.dat", idBrandCateFile="json/id2cate_brand_tmall.json", cateFile="json/merge_c_tmall")
# getBrand(brandList="json/label.dat")

In [32]:
len(groupMeta[0])+len(groupMeta[1])+len(groupMeta[2])

791

In [36]:
def garbageList(garbagePath):
    '''
    garbagePath is "json/garbage_result.json".
    return: [json]. a list of json.
    '''
    # get garbage metadata. 
    # we only care about `_id` and `img` currently. in the future, we may also interest in `title` and `attrs`
    garbageMeta = readJson(garbagePath)
    ret = []
    for record in garbageMeta:
        if "img" in record:
            ret.append(record)
    return ret
            
len(garbageList(garbagePath="json/garbage_result.json"))

1998

In [41]:
import numpy as np

def randomChoiceBoolean(totalnum, num):
    '''
    choose num from totalnum randomly.
    return a boolean list.
    the length of list = totalnum. 
    the number of True = num.
    '''
    ret = np.zeros(totalnum, dtype=bool)
    pick = np.random.choice(np.arange(totalnum), num, replace = False)
    ret[pick] = True
    return ret

def randomSelectFixNumber(collection, num):
    '''
    random select `num` record in collection.
    collection: is a list. any list is OK.
    num: the num of record selected
    return: a list. subset-collection, whose length = num
    '''
    totalnum = len(collection)
    pick = randomChoiceBoolean(totalnum, num)
    # print(np.where(pick==True))
    return list(np.array(collection)[pick])
    
# randomChoiceBoolean(10, 3)
# assume each SPU has 5 images. 2000 images / 5 = 400 SPU
# in garbage_result.json, there are 1998 SPU, select 400 among them
garbageMeta = randomSelectFixNumber(garbageList(garbagePath="json/garbage_result.json"), 400)

In [44]:
def groupFullByBrandList(garbageMeta, groupMeta):
    '''
    groupMeta is output of groupTmallByBrandList. [[gree jsons], [xiaomi jsons], [lenovo jsons], ...]
    garbageMeta is output of `randomSelectFixNumber(garbageList, 400)`. [other jsons], a flat list of records
    return: [[gree jsons], [xiaomi jsons], [lenovo jsons], ..., [other jsons]]
    len(return) = number of lines in label.dat
    '''
    ret = copy.deepcopy(groupMeta)
    ret.append(garbageMeta)
    return ret

# final metadata list. training set and validation set will be chosen randomly from it
groupMetaFull = groupFullByBrandList(garbageMeta, groupMeta)

In [45]:
len(groupMetaFull)

4

In [48]:
def getFullMetaStatistic(groupCollection):
    '''
    groupCollection is output of `groupTmallByBrandList` or `groupFullByBrandList`
    get statistic info of groupCollection.
    return: spu_list (number of SPUs per brand), img_list (number of images per brand)
    '''
    spu_list = []
    img_list = []
    for brandCollection in groupCollection:
        spu_list.append(len(brandCollection))
        nimg = 0
        for record in brandCollection:
            nimg += len(record['img'])
        img_list.append(nimg)
    return spu_list, img_list
        
print(getFullMetaStatistic(groupMeta))
# print(getFullMetaStatistic(garbageMeta)) # wrong: garbageMeta is a flat list of records, not a list of brand groups
print(getFullMetaStatistic(groupMetaFull))


([214, 421, 156], [1038, 1970, 759])
([214, 421, 156, 400], [1038, 1970, 759, 1983])


In [52]:
def randomPickSPUTmallGarbarge(groupMetaFull, outputDir, percentage = 0.7):
    '''
    write two files, one line per image: "SPU_id image_path brand_idx"
    1. train.txt for the training set
    2. val.txt for the validation set
    we only care about `_id` and `img` currently. in the future, we may also be interested in `title` and `attrs`

    groupMetaFull: output of groupFullByBrandList
    there are len(groupMetaFull)-1 brands and one `other` category (the last one)
    for each category:
    int(percentage*(SPU in this category)) SPUs are included in the training set.
    the remaining SPUs are included in the validation set.
    the split is done per SPU, so all images of one SPU land in the same set.
    no return.
    '''
    numbrand = len(groupMetaFull)

    def writeRecords(outfile, records, idx):
        # images of the last category (`other`) live in the garbage directory
        if idx != numbrand - 1:
            imgdir = "../tmall_pic/"
        else:
            imgdir = "../garbage_pic/new_pic/"
        for record in records:
            for img in record['img']:
                outfile.write(str(record["_id"]) + " " + imgdir + img + " " + str(idx) + "\n")

    # outputDir must end with a path separator, e.g. "json/"
    trainfile = open(outputDir + "train.txt", 'w')
    valfile = open(outputDir + "val.txt", 'w')
    for idx, brand in enumerate(groupMetaFull):
        totalnum = len(brand)
        num = int(totalnum * percentage)
        mask = randomChoiceBoolean(totalnum, num)
        writeRecords(trainfile, np.array(brand)[mask], idx)   # SPU for training
        writeRecords(valfile, np.array(brand)[~mask], idx)    # SPU for validation
    trainfile.close()
    valfile.close()

    
# randomPickSPUTmallGarbarge(groupMetaFull, "json/", 0.7)

In [53]:
def checkOverlap(trainFile, valFile):
    '''
    return the images that appear in both trainFile and valFile.
    if output = [], there is no overlap between training images and validation images.
    note: only image paths are compared; SPU ids are not checked here.
    '''
    ret = []
    trainset = set()
    with open(trainFile) as train:
        for line in train:
            spuid, img, cat = line.strip().split(" ")
            trainset.add(img)
    with open(valFile) as val:
        for line in val:
            spuid, img, cat = line.strip().split(" ")
            if img in trainset:
                ret.append(img)
    return ret
        
checkOverlap('json/train.txt', 'json/val.txt') 

[]
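
Since the goal is to evaluate whole SPUs, it may also be worth checking that no SPU id appears in both splits. A minimal sketch, assuming the `SPU_id image_path brand_idx` line format written to `train.txt` and `val.txt`; the name `spu_overlap` is hypothetical, and the function takes any iterable of lines (an open file or a list):

```python
def spu_overlap(train_lines, val_lines):
    """Return the sorted SPU ids that appear in both splits."""
    def spu_ids(lines):
        ids = set()
        for line in lines:
            line = line.strip()
            if line:
                ids.add(line.split(" ")[0])    # the first field is the SPU id
        return ids

    return sorted(spu_ids(train_lines) & spu_ids(val_lines))


# usage with the files written above (paths as used in this notebook):
# with open('json/train.txt') as t, open('json/val.txt') as v:
#     print(spu_overlap(t, v))
```

An empty list means the split is clean at the SPU level as well as at the image level.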

In [None]:
def trainvalNLP(trainFile, valFile):
    '''
    read SPU id from train.txt and val.txt
    TODO: output [SPU_id, title, attr, brand] to another (two) files.
    '''
    raise NotImplementedError("TODO")